Refactor s3select to support parquet. (#7023)

Also handle pretty formatted JSON documents.
Bala FA 2019-01-09 06:23:04 +05:30 committed by kannappanr
parent e98d89274f
commit b0deea27df
124 changed files with 27376 additions and 4152 deletions

.travis.yml

@@ -16,22 +16,24 @@ matrix:
      sudo: required
      env:
      - ARCH=x86_64
+     - CGO_ENABLED=0
      go: 1.11.4
      script:
      - make
      - diff -au <(gofmt -s -d cmd) <(printf "")
      - diff -au <(gofmt -s -d pkg) <(printf "")
-     - for d in $(go list ./... | grep -v browser); do go test -v -race --timeout 15m "$d"; done
+     - for d in $(go list ./... | grep -v browser); do CGO_ENABLED=1 go test -v -race --timeout 15m "$d"; done
      - make verify
      - make coverage
      - cd browser && yarn && yarn test && cd ..
    - os: windows
      env:
      - ARCH=x86_64
+     - CGO_ENABLED=0
      go: 1.11.4
      script:
      - go build --ldflags="$(go run buildscripts/gen-ldflags.go)" -o %GOPATH%\bin\minio.exe
-     - for d in $(go list ./... | grep -v browser); do go test -v -race --timeout 20m "$d"; done
+     - for d in $(go list ./... | grep -v browser); do CGO_ENABLED=1 go test -v -race --timeout 20m "$d"; done
      - bash buildscripts/go-coverage.sh
before_script:

Makefile

@@ -60,7 +60,7 @@ spelling:
check: test

test: verifiers build
	@echo "Running unit tests"
-	@go test -tags kqueue ./...
+	@CGO_ENABLED=0 go test -tags kqueue ./...

verify: build
	@echo "Verifying build"

buildscripts/go-coverage.sh

@@ -4,7 +4,7 @@ set -e

echo "" > coverage.txt
for d in $(go list ./... | grep -v browser); do
-    go test -v -coverprofile=profile.out -covermode=atomic "$d"
+    CGO_ENABLED=0 go test -v -coverprofile=profile.out -covermode=atomic "$d"
    if [ -f profile.out ]; then
        cat profile.out >> coverage.txt
        rm profile.out

cmd/api-errors.go

@@ -28,8 +28,6 @@ import (
	"github.com/minio/minio/pkg/dns"
	"github.com/minio/minio/pkg/event"
	"github.com/minio/minio/pkg/hash"
-	"github.com/minio/minio/pkg/s3select"
-	"github.com/minio/minio/pkg/s3select/format"
)

// APIError structure
@@ -1512,168 +1510,6 @@ func toAPIErrorCode(ctx context.Context, err error) (apiErr APIErrorCode) {
	case errOperationTimedOut, context.Canceled, context.DeadlineExceeded:
		apiErr = ErrOperationTimedOut
	}
switch err {
case s3select.ErrBusy:
apiErr = ErrBusy
case s3select.ErrUnauthorizedAccess:
apiErr = ErrUnauthorizedAccess
case s3select.ErrExpressionTooLong:
apiErr = ErrExpressionTooLong
case s3select.ErrIllegalSQLFunctionArgument:
apiErr = ErrIllegalSQLFunctionArgument
case s3select.ErrInvalidKeyPath:
apiErr = ErrInvalidKeyPath
case s3select.ErrInvalidCompressionFormat:
apiErr = ErrInvalidCompressionFormat
case s3select.ErrInvalidFileHeaderInfo:
apiErr = ErrInvalidFileHeaderInfo
case s3select.ErrInvalidJSONType:
apiErr = ErrInvalidJSONType
case s3select.ErrInvalidQuoteFields:
apiErr = ErrInvalidQuoteFields
case s3select.ErrInvalidRequestParameter:
apiErr = ErrInvalidRequestParameter
case s3select.ErrInvalidDataType:
apiErr = ErrInvalidDataType
case s3select.ErrInvalidTextEncoding:
apiErr = ErrInvalidTextEncoding
case s3select.ErrInvalidTableAlias:
apiErr = ErrInvalidTableAlias
case s3select.ErrMissingRequiredParameter:
apiErr = ErrMissingRequiredParameter
case s3select.ErrObjectSerializationConflict:
apiErr = ErrObjectSerializationConflict
case s3select.ErrUnsupportedSQLOperation:
apiErr = ErrUnsupportedSQLOperation
case s3select.ErrUnsupportedSQLStructure:
apiErr = ErrUnsupportedSQLStructure
case s3select.ErrUnsupportedSyntax:
apiErr = ErrUnsupportedSyntax
case s3select.ErrUnsupportedRangeHeader:
apiErr = ErrUnsupportedRangeHeader
case s3select.ErrLexerInvalidChar:
apiErr = ErrLexerInvalidChar
case s3select.ErrLexerInvalidOperator:
apiErr = ErrLexerInvalidOperator
case s3select.ErrLexerInvalidLiteral:
apiErr = ErrLexerInvalidLiteral
case s3select.ErrLexerInvalidIONLiteral:
apiErr = ErrLexerInvalidIONLiteral
case s3select.ErrParseExpectedDatePart:
apiErr = ErrParseExpectedDatePart
case s3select.ErrParseExpectedKeyword:
apiErr = ErrParseExpectedKeyword
case s3select.ErrParseExpectedTokenType:
apiErr = ErrParseExpectedTokenType
case s3select.ErrParseExpected2TokenTypes:
apiErr = ErrParseExpected2TokenTypes
case s3select.ErrParseExpectedNumber:
apiErr = ErrParseExpectedNumber
case s3select.ErrParseExpectedRightParenBuiltinFunctionCall:
apiErr = ErrParseExpectedRightParenBuiltinFunctionCall
case s3select.ErrParseExpectedTypeName:
apiErr = ErrParseExpectedTypeName
case s3select.ErrParseExpectedWhenClause:
apiErr = ErrParseExpectedWhenClause
case s3select.ErrParseUnsupportedToken:
apiErr = ErrParseUnsupportedToken
case s3select.ErrParseUnsupportedLiteralsGroupBy:
apiErr = ErrParseUnsupportedLiteralsGroupBy
case s3select.ErrParseExpectedMember:
apiErr = ErrParseExpectedMember
case s3select.ErrParseUnsupportedSelect:
apiErr = ErrParseUnsupportedSelect
case s3select.ErrParseUnsupportedCase:
apiErr = ErrParseUnsupportedCase
case s3select.ErrParseUnsupportedCaseClause:
apiErr = ErrParseUnsupportedCaseClause
case s3select.ErrParseUnsupportedAlias:
apiErr = ErrParseUnsupportedAlias
case s3select.ErrParseUnsupportedSyntax:
apiErr = ErrParseUnsupportedSyntax
case s3select.ErrParseUnknownOperator:
apiErr = ErrParseUnknownOperator
case s3select.ErrParseMissingIdentAfterAt:
apiErr = ErrParseMissingIdentAfterAt
case s3select.ErrParseUnexpectedOperator:
apiErr = ErrParseUnexpectedOperator
case s3select.ErrParseUnexpectedTerm:
apiErr = ErrParseUnexpectedTerm
case s3select.ErrParseUnexpectedToken:
apiErr = ErrParseUnexpectedToken
case s3select.ErrParseUnexpectedKeyword:
apiErr = ErrParseUnexpectedKeyword
case s3select.ErrParseExpectedExpression:
apiErr = ErrParseExpectedExpression
case s3select.ErrParseExpectedLeftParenAfterCast:
apiErr = ErrParseExpectedLeftParenAfterCast
case s3select.ErrParseExpectedLeftParenValueConstructor:
apiErr = ErrParseExpectedLeftParenValueConstructor
case s3select.ErrParseExpectedLeftParenBuiltinFunctionCall:
apiErr = ErrParseExpectedLeftParenBuiltinFunctionCall
case s3select.ErrParseExpectedArgumentDelimiter:
apiErr = ErrParseExpectedArgumentDelimiter
case s3select.ErrParseCastArity:
apiErr = ErrParseCastArity
case s3select.ErrParseInvalidTypeParam:
apiErr = ErrParseInvalidTypeParam
case s3select.ErrParseEmptySelect:
apiErr = ErrParseEmptySelect
case s3select.ErrParseSelectMissingFrom:
apiErr = ErrParseSelectMissingFrom
case s3select.ErrParseExpectedIdentForGroupName:
apiErr = ErrParseExpectedIdentForGroupName
case s3select.ErrParseExpectedIdentForAlias:
apiErr = ErrParseExpectedIdentForAlias
case s3select.ErrParseUnsupportedCallWithStar:
apiErr = ErrParseUnsupportedCallWithStar
case s3select.ErrParseNonUnaryAgregateFunctionCall:
apiErr = ErrParseNonUnaryAgregateFunctionCall
case s3select.ErrParseMalformedJoin:
apiErr = ErrParseMalformedJoin
case s3select.ErrParseExpectedIdentForAt:
apiErr = ErrParseExpectedIdentForAt
case s3select.ErrParseAsteriskIsNotAloneInSelectList:
apiErr = ErrParseAsteriskIsNotAloneInSelectList
case s3select.ErrParseCannotMixSqbAndWildcardInSelectList:
apiErr = ErrParseCannotMixSqbAndWildcardInSelectList
case s3select.ErrParseInvalidContextForWildcardInSelectList:
apiErr = ErrParseInvalidContextForWildcardInSelectList
case s3select.ErrIncorrectSQLFunctionArgumentType:
apiErr = ErrIncorrectSQLFunctionArgumentType
case s3select.ErrValueParseFailure:
apiErr = ErrValueParseFailure
case s3select.ErrIntegerOverflow:
apiErr = ErrIntegerOverflow
case s3select.ErrLikeInvalidInputs:
apiErr = ErrLikeInvalidInputs
case s3select.ErrCastFailed:
apiErr = ErrCastFailed
case s3select.ErrInvalidCast:
apiErr = ErrInvalidCast
case s3select.ErrEvaluatorInvalidTimestampFormatPattern:
apiErr = ErrEvaluatorInvalidTimestampFormatPattern
case s3select.ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing:
apiErr = ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing
case s3select.ErrEvaluatorTimestampFormatPatternDuplicateFields:
apiErr = ErrEvaluatorTimestampFormatPatternDuplicateFields
case s3select.ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch:
apiErr = ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch
case s3select.ErrEvaluatorUnterminatedTimestampFormatPatternToken:
apiErr = ErrEvaluatorUnterminatedTimestampFormatPatternToken
case s3select.ErrEvaluatorInvalidTimestampFormatPatternToken:
apiErr = ErrEvaluatorInvalidTimestampFormatPatternToken
case s3select.ErrEvaluatorInvalidTimestampFormatPatternSymbol:
apiErr = ErrEvaluatorInvalidTimestampFormatPatternSymbol
case s3select.ErrEvaluatorBindingDoesNotExist:
apiErr = ErrEvaluatorBindingDoesNotExist
case s3select.ErrMissingHeaders:
apiErr = ErrMissingHeaders
case format.ErrParseInvalidPathComponent:
apiErr = ErrMissingHeaders
case format.ErrInvalidColumnIndex:
apiErr = ErrInvalidColumnIndex
}
	// Compression errors
	switch err {

cmd/object-handlers.go

@@ -33,7 +33,6 @@ import (
	snappy "github.com/golang/snappy"
	"github.com/gorilla/mux"
-	"github.com/klauspost/readahead"
	miniogo "github.com/minio/minio-go"
	"github.com/minio/minio-go/pkg/encrypt"
	"github.com/minio/minio/cmd/crypto"
@@ -106,6 +105,12 @@ func (api objectAPIHandlers) SelectObjectContentHandler(w http.ResponseWriter, r
		writeErrorResponseHeadersOnly(w, toAPIErrorCode(ctx, err))
		return
	}

+	getObjectInfo := objectAPI.GetObjectInfo
+	if api.CacheAPI() != nil {
+		getObjectInfo = api.CacheAPI().GetObjectInfo
+	}
+
	// Check for auth type to return S3 compatible error.
	// type to return the correct error (NoSuchKey vs AccessDenied)
	if s3Error := checkRequestAuthType(ctx, r, policy.GetObjectAction, bucket, object); s3Error != ErrNone {
@@ -129,11 +134,6 @@ func (api objectAPIHandlers) SelectObjectContentHandler(w http.ResponseWriter, r
		ConditionValues: getConditionValues(r, ""),
		IsOwner:         false,
	}) {
-		getObjectInfo := objectAPI.GetObjectInfo
-		if api.CacheAPI() != nil {
-			getObjectInfo = api.CacheAPI().GetObjectInfo
-		}
		_, err = getObjectInfo(ctx, bucket, object, opts)
		if toAPIErrorCode(ctx, err) == ErrNoSuchKey {
			s3Error = ErrNoSuchKey
@@ -156,18 +156,14 @@ func (api objectAPIHandlers) SelectObjectContentHandler(w http.ResponseWriter, r
		return
	}

-	var selectReq s3select.ObjectSelectRequest
-	if err := xmlDecoder(r.Body, &selectReq, r.ContentLength); err != nil {
-		writeErrorResponse(w, ErrMalformedXML, r.URL, guessIsBrowserReq(r))
-		return
-	}
-
-	if !strings.EqualFold(string(selectReq.ExpressionType), "SQL") {
-		writeErrorResponse(w, ErrInvalidExpressionType, r.URL, guessIsBrowserReq(r))
-		return
-	}
-	if len(selectReq.Expression) >= s3select.MaxExpressionLength {
-		writeErrorResponse(w, ErrExpressionTooLong, r.URL, guessIsBrowserReq(r))
+	s3Select, err := s3select.NewS3Select(r.Body)
+	if err != nil {
+		if serr, ok := err.(s3select.SelectError); ok {
+			w.WriteHeader(serr.HTTPStatusCode())
+			w.Write(s3select.NewErrorMessage(serr.ErrorCode(), serr.ErrorMessage()))
+		} else {
+			writeErrorResponse(w, ErrInternalError, r.URL, guessIsBrowserReq(r))
+		}
		return
	}
@@ -175,123 +171,38 @@ func (api objectAPIHandlers) SelectObjectContentHandler(w http.ResponseWriter, r
	if api.CacheAPI() != nil {
		getObjectNInfo = api.CacheAPI().GetObjectNInfo
	}

-	gr, err := getObjectNInfo(ctx, bucket, object, nil, r.Header, readLock, opts)
+	getObject := func(offset, length int64) (rc io.ReadCloser, err error) {
+		isSuffixLength := false
+		if offset < 0 {
+			isSuffixLength = true
+		}
+
+		rs := &HTTPRangeSpec{
+			IsSuffixLength: isSuffixLength,
+			Start:          offset,
+			End:            length,
+		}
+
+		return getObjectNInfo(ctx, bucket, object, rs, r.Header, readLock, ObjectOptions{})
+	}
+
+	if err = s3Select.Open(getObject); err != nil {
+		if serr, ok := err.(s3select.SelectError); ok {
+			w.WriteHeader(serr.HTTPStatusCode())
+			w.Write(s3select.NewErrorMessage(serr.ErrorCode(), serr.ErrorMessage()))
+		} else {
+			writeErrorResponse(w, ErrInternalError, r.URL, guessIsBrowserReq(r))
+		}
+		return
+	}
+
+	s3Select.Evaluate(w)
+	s3Select.Close()
+
+	objInfo, err := getObjectInfo(ctx, bucket, object, opts)
	if err != nil {
-		writeErrorResponse(w, toAPIErrorCode(ctx, err), r.URL, guessIsBrowserReq(r))
+		logger.LogIf(ctx, err)
		return
	}
defer gr.Close()
objInfo := gr.ObjInfo
if selectReq.InputSerialization.CompressionType == s3select.SelectCompressionGZIP {
if !strings.Contains(objInfo.ContentType, "gzip") {
writeErrorResponse(w, ErrInvalidDataSource, r.URL, guessIsBrowserReq(r))
return
}
}
if selectReq.InputSerialization.CompressionType == s3select.SelectCompressionBZIP {
if !strings.Contains(objInfo.ContentType, "bzip") {
writeErrorResponse(w, ErrInvalidDataSource, r.URL, guessIsBrowserReq(r))
return
}
}
if selectReq.InputSerialization.CompressionType == "" {
selectReq.InputSerialization.CompressionType = s3select.SelectCompressionNONE
if !strings.Contains(objInfo.ContentType, "text/csv") && !strings.Contains(objInfo.ContentType, "application/json") {
writeErrorResponse(w, ErrInvalidDataSource, r.URL, guessIsBrowserReq(r))
return
}
}
if !strings.EqualFold(string(selectReq.ExpressionType), "SQL") {
writeErrorResponse(w, ErrInvalidExpressionType, r.URL, guessIsBrowserReq(r))
return
}
if len(selectReq.Expression) >= s3select.MaxExpressionLength {
writeErrorResponse(w, ErrExpressionTooLong, r.URL, guessIsBrowserReq(r))
return
}
if selectReq.InputSerialization.CSV == nil && selectReq.InputSerialization.JSON == nil {
writeErrorResponse(w, ErrInvalidRequestParameter, r.URL, guessIsBrowserReq(r))
return
}
if selectReq.OutputSerialization.CSV == nil && selectReq.OutputSerialization.JSON == nil {
writeErrorResponse(w, ErrInvalidRequestParameter, r.URL, guessIsBrowserReq(r))
return
}
if selectReq.InputSerialization.CSV != nil {
if selectReq.InputSerialization.CSV.FileHeaderInfo != s3select.CSVFileHeaderInfoUse &&
selectReq.InputSerialization.CSV.FileHeaderInfo != s3select.CSVFileHeaderInfoNone &&
selectReq.InputSerialization.CSV.FileHeaderInfo != s3select.CSVFileHeaderInfoIgnore &&
selectReq.InputSerialization.CSV.FileHeaderInfo != "" {
writeErrorResponse(w, ErrInvalidFileHeaderInfo, r.URL, guessIsBrowserReq(r))
return
}
if selectReq.OutputSerialization.CSV != nil {
if selectReq.OutputSerialization.CSV.QuoteFields != s3select.CSVQuoteFieldsAlways &&
selectReq.OutputSerialization.CSV.QuoteFields != s3select.CSVQuoteFieldsAsNeeded &&
selectReq.OutputSerialization.CSV.QuoteFields != "" {
writeErrorResponse(w, ErrInvalidQuoteFields, r.URL, guessIsBrowserReq(r))
return
}
}
if len(selectReq.InputSerialization.CSV.RecordDelimiter) > 2 {
writeErrorResponse(w, ErrInvalidRequestParameter, r.URL, guessIsBrowserReq(r))
return
}
}
if selectReq.InputSerialization.JSON != nil {
if selectReq.InputSerialization.JSON.Type != s3select.JSONLinesType {
writeErrorResponse(w, ErrInvalidJSONType, r.URL, guessIsBrowserReq(r))
return
}
}
// Set encryption response headers
if objectAPI.IsEncryptionSupported() {
objInfo.UserDefined = CleanMinioInternalMetadataKeys(objInfo.UserDefined)
if crypto.IsEncrypted(objInfo.UserDefined) {
switch {
case crypto.S3.IsEncrypted(objInfo.UserDefined):
w.Header().Set(crypto.SSEHeader, crypto.SSEAlgorithmAES256)
case crypto.SSEC.IsEncrypted(objInfo.UserDefined):
w.Header().Set(crypto.SSECAlgorithm, r.Header.Get(crypto.SSECAlgorithm))
w.Header().Set(crypto.SSECKeyMD5, r.Header.Get(crypto.SSECKeyMD5))
}
}
}
reader := readahead.NewReader(gr)
defer reader.Close()
size := objInfo.Size
if objInfo.IsCompressed() {
size = objInfo.GetActualSize()
if size < 0 {
writeErrorResponse(w, toAPIErrorCode(ctx, errInvalidDecompressedSize), r.URL, guessIsBrowserReq(r))
return
}
}
s3s, err := s3select.New(reader, size, selectReq)
if err != nil {
writeErrorResponse(w, toAPIErrorCode(ctx, err), r.URL, guessIsBrowserReq(r))
return
}
// Parses the select query and checks for an error
_, _, _, _, _, _, err = s3select.ParseSelect(s3s)
if err != nil {
writeErrorResponse(w, toAPIErrorCode(ctx, err), r.URL, guessIsBrowserReq(r))
return
}
// Executes the query on data-set
s3select.Execute(w, s3s)
	// Get host and port from Request.RemoteAddr.
	host, port, err := net.SplitHostPort(handlers.GetSourceIP(r))
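
Taken together, the handler now delegates all request parsing and format validation to the s3select package: parse the XML body, hand the evaluator a range-based getter (Parquet needs random access, so a single forward stream is no longer enough), and stream the results. A minimal sketch of that flow, using only the API surface visible in this diff; the serveSelect and writeSelectErr wrappers and their arguments are illustrative, not part of the commit:

package example

import (
	"io"
	"net/http"

	"github.com/minio/minio/pkg/s3select"
)

// serveSelect is a hypothetical condensed view of the new handler flow.
func serveSelect(w http.ResponseWriter, body io.Reader,
	getObject func(offset, length int64) (io.ReadCloser, error)) {
	// Parse and validate the SelectObjectContentRequest XML in one step.
	s3Select, err := s3select.NewS3Select(body)
	if err != nil {
		writeSelectErr(w, err)
		return
	}

	// Open is lazy: the getter is invoked with byte ranges, which lets the
	// Parquet reader seek instead of consuming one forward stream.
	if err = s3Select.Open(getObject); err != nil {
		writeSelectErr(w, err)
		return
	}

	// Evaluate streams the message-framed S3 Select response to the client.
	s3Select.Evaluate(w)
	s3Select.Close()
}

func writeSelectErr(w http.ResponseWriter, err error) {
	// SelectError carries the S3-compatible code, message, and status.
	if serr, ok := err.(s3select.SelectError); ok {
		w.WriteHeader(serr.HTTPStatusCode())
		w.Write(s3select.NewErrorMessage(serr.ErrorCode(), serr.ErrorMessage()))
		return
	}
	w.WriteHeader(http.StatusInternalServerError)
}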

pkg/ioutil (deleted file)

@@ -1,87 +0,0 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ioutil
import (
"bufio"
"io"
)
var (
nByte byte = 10 // the byte that corresponds to the '\n' rune.
rByte byte = 13 // the byte that corresponds to the '\r' rune.
)
// DelimitedReader reduces the custom delimiter to `\n`.
type DelimitedReader struct {
r *bufio.Reader
delimiter []rune // Select can have up to 2 characters as delimiter.
assignEmpty bool // Decides whether the next read byte should be discarded.
}
// NewDelimitedReader detects the custom delimiter and replaces with `\n`.
func NewDelimitedReader(r io.Reader, delimiter []rune) *DelimitedReader {
return &DelimitedReader{r: bufio.NewReader(r), delimiter: delimiter, assignEmpty: false}
}
// Reads and replaces the custom delimiter with `\n`.
func (r *DelimitedReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
if err != nil {
return
}
for i, b := range p {
if r.assignEmpty {
swapAndNullify(p, i)
r.assignEmpty = false
continue
}
if b == rByte && rune(b) != r.delimiter[0] {
// Replace the carriage returns with `\n`.
// Mac styled csv will have `\r` as their record delimiter.
p[i] = nByte
} else if rune(b) == r.delimiter[0] { // Eg, `\r\n`,`ab`,`a` are valid delimiters
if i+1 == len(p) && len(r.delimiter) > 1 {
// If the first delimiter match falls on the boundary,
// Peek the next byte and if it matches, discard it in the next byte read.
if nextByte, nerr := r.r.Peek(1); nerr == nil {
if rune(nextByte[0]) == r.delimiter[1] {
p[i] = nByte
// To Discard in the next read.
r.assignEmpty = true
}
}
} else if len(r.delimiter) > 1 && rune(p[i+1]) == r.delimiter[1] {
// The second delimiter falls in the same chunk.
p[i] = nByte
r.assignEmpty = true
} else if len(r.delimiter) == 1 {
// Replace with `\n` in case of single character delimiter match.
p[i] = nByte
}
}
}
return
}
// Occupy the first byte space and nullify the last byte.
func swapAndNullify(p []byte, n int) {
for i := n; i < len(p)-1; i++ {
p[i] = p[i+1]
}
p[len(p)-1] = 0
}

pkg/ioutil (deleted test file)

@@ -1,83 +0,0 @@
/*
* Minio Cloud Storage, (C) 2016, 2017, 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ioutil
import (
"bytes"
"io"
"strings"
"testing"
)
// Test for DelimitedReader.
func TestDelimitedReader(t *testing.T) {
expected := "username,age\nbanana,12\ncarrot,23\napple,34\nbrinjal,90\nraddish,45"
inputs := []struct {
inputcsv string
delimiter string
chunkSize int
}{
// case 1 - with default `\n` delimiter.
{"username,age\nbanana,12\ncarrot,23\napple,34\nbrinjal,90\nraddish,45", "\n", 10},
// case 2 - with carriage return `\r` which should be replaced with `\n` by default.
{"username,age\rbanana,12\rcarrot,23\rapple,34\rbrinjal,90\rraddish,45", "\n", 10},
// case 3 - with a double character delimiter (octals).
{"username,age\r\nbanana,12\r\ncarrot,23\r\napple,34\r\nbrinjal,90\r\nraddish,45", "\r\n", 10},
// case 4 - with a double character delimiter.
{"username,agexvbanana,12xvcarrot,23xvapple,34xvbrinjal,90xvraddish,45", "xv", 10},
// case 5 - with a double character delimiter `\t `
{"username,age\t banana,12\t carrot,23\t apple,34\t brinjal,90\t raddish,45", "\t ", 10},
// case 6 - This is a special case where the first delimiter match falls in the 13'th byte space
// ie, the last byte space of the read chunk, In this case the reader should peek in the next byte
// and replace with `\n`.
{"username,agexxbanana,12xxcarrot,23xxapple,34xxbrinjal,90xxraddish,45", "xx", 13},
}
for c, input := range inputs {
var readcsv []byte
var err error
delimitedReader := NewDelimitedReader(strings.NewReader(input.inputcsv), []rune(input.delimiter))
for err == nil {
chunk := make([]byte, input.chunkSize)
_, err = delimitedReader.Read(chunk)
readcsv = append(readcsv, chunk...)
}
if err != io.EOF {
t.Fatalf("Case %d: Error in delimited read", c+1)
}
expected := []byte(expected)
cleanCsv := removeNulls(readcsv)
if !bytes.Equal(cleanCsv, expected) {
t.Fatalf("Case %d: Expected the delimited csv to be `%s`, but instead found `%s`", c+1, string(expected), string(cleanCsv))
}
}
}
// Removes all the tailing nulls in chunks.
// Null chunks will be assigned if there is a reduction
// Eg, When `xv` is reduced to `\n`, the last byte is nullified.
func removeNulls(csv []byte) []byte {
cleanCsv := []byte{}
for _, p := range csv {
if p != 0 {
cleanCsv = append(cleanCsv, p)
}
}
return cleanCsv
}

pkg/s3select/csv/args.go (new file, 190 lines)

@@ -0,0 +1,190 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package csv
import (
"encoding/xml"
"fmt"
"strings"
)
const (
none = "none"
use = "use"
ignore = "ignore"
defaultRecordDelimiter = "\n"
defaultFieldDelimiter = ","
defaultQuoteCharacter = `"`
defaultQuoteEscapeCharacter = `"`
defaultCommentCharacter = "#"
always = "always"
asneeded = "asneeded"
)
// ReaderArgs - represents elements inside <InputSerialization><CSV> in request XML.
type ReaderArgs struct {
FileHeaderInfo string `xml:"FileHeaderInfo"`
RecordDelimiter string `xml:"RecordDelimiter"`
FieldDelimiter string `xml:"FieldDelimiter"`
QuoteCharacter string `xml:"QuoteCharacter"`
QuoteEscapeCharacter string `xml:"QuoteEscapeCharacter"`
CommentCharacter string `xml:"Comments"`
AllowQuotedRecordDelimiter bool `xml:"AllowQuotedRecordDelimiter"`
unmarshaled bool
}
// IsEmpty - returns whether reader args is empty or not.
func (args *ReaderArgs) IsEmpty() bool {
return !args.unmarshaled
}
// UnmarshalXML - decodes XML data.
func (args *ReaderArgs) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
// Make subtype to avoid recursive UnmarshalXML().
type subReaderArgs ReaderArgs
parsedArgs := subReaderArgs{}
if err := d.DecodeElement(&parsedArgs, &start); err != nil {
return err
}
parsedArgs.FileHeaderInfo = strings.ToLower(parsedArgs.FileHeaderInfo)
switch parsedArgs.FileHeaderInfo {
case none, use, ignore:
default:
return errInvalidFileHeaderInfo(fmt.Errorf("invalid FileHeaderInfo '%v'", parsedArgs.FileHeaderInfo))
}
switch len(parsedArgs.RecordDelimiter) {
case 0:
parsedArgs.RecordDelimiter = defaultRecordDelimiter
case 1, 2:
default:
return fmt.Errorf("invalid RecordDelimiter '%v'", parsedArgs.RecordDelimiter)
}
switch len(parsedArgs.FieldDelimiter) {
case 0:
parsedArgs.FieldDelimiter = defaultFieldDelimiter
case 1:
default:
return fmt.Errorf("invalid FieldDelimiter '%v'", parsedArgs.FieldDelimiter)
}
switch parsedArgs.QuoteCharacter {
case "":
parsedArgs.QuoteCharacter = defaultQuoteCharacter
case defaultQuoteCharacter:
default:
return fmt.Errorf("unsupported QuoteCharacter '%v'", parsedArgs.QuoteCharacter)
}
switch parsedArgs.QuoteEscapeCharacter {
case "":
parsedArgs.QuoteEscapeCharacter = defaultQuoteEscapeCharacter
case defaultQuoteEscapeCharacter:
default:
return fmt.Errorf("unsupported QuoteEscapeCharacter '%v'", parsedArgs.QuoteEscapeCharacter)
}
switch parsedArgs.CommentCharacter {
case "":
parsedArgs.CommentCharacter = defaultCommentCharacter
case defaultCommentCharacter:
default:
return fmt.Errorf("unsupported Comments '%v'", parsedArgs.CommentCharacter)
}
if parsedArgs.AllowQuotedRecordDelimiter {
return fmt.Errorf("flag AllowQuotedRecordDelimiter is unsupported at the moment")
}
*args = ReaderArgs(parsedArgs)
args.unmarshaled = true
return nil
}
// WriterArgs - represents elements inside <OutputSerialization><CSV/> in request XML.
type WriterArgs struct {
QuoteFields string `xml:"QuoteFields"`
RecordDelimiter string `xml:"RecordDelimiter"`
FieldDelimiter string `xml:"FieldDelimiter"`
QuoteCharacter string `xml:"QuoteCharacter"`
QuoteEscapeCharacter string `xml:"QuoteEscapeCharacter"`
unmarshaled bool
}
// IsEmpty - returns whether writer args is empty or not.
func (args *WriterArgs) IsEmpty() bool {
return !args.unmarshaled
}
// UnmarshalXML - decodes XML data.
func (args *WriterArgs) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
// Make subtype to avoid recursive UnmarshalXML().
type subWriterArgs WriterArgs
parsedArgs := subWriterArgs{}
if err := d.DecodeElement(&parsedArgs, &start); err != nil {
return err
}
parsedArgs.QuoteFields = strings.ToLower(parsedArgs.QuoteFields)
switch parsedArgs.QuoteFields {
case "":
parsedArgs.QuoteFields = asneeded
case always, asneeded:
default:
return errInvalidQuoteFields(fmt.Errorf("invalid QuoteFields '%v'", parsedArgs.QuoteFields))
}
switch len(parsedArgs.RecordDelimiter) {
case 0:
parsedArgs.RecordDelimiter = defaultRecordDelimiter
case 1, 2:
default:
return fmt.Errorf("invalid RecordDelimiter '%v'", parsedArgs.RecordDelimiter)
}
switch len(parsedArgs.FieldDelimiter) {
case 0:
parsedArgs.FieldDelimiter = defaultFieldDelimiter
case 1:
default:
return fmt.Errorf("invalid FieldDelimiter '%v'", parsedArgs.FieldDelimiter)
}
switch parsedArgs.QuoteCharacter {
case "":
parsedArgs.QuoteCharacter = defaultQuoteCharacter
case defaultQuoteCharacter:
default:
return fmt.Errorf("unsupported QuoteCharacter '%v'", parsedArgs.QuoteCharacter)
}
switch parsedArgs.QuoteEscapeCharacter {
case "":
parsedArgs.QuoteEscapeCharacter = defaultQuoteEscapeCharacter
case defaultQuoteEscapeCharacter:
default:
return fmt.Errorf("unsupported QuoteEscapeCharacter '%v'", parsedArgs.QuoteEscapeCharacter)
}
*args = WriterArgs(parsedArgs)
args.unmarshaled = true
return nil
}
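
These argument types are only ever populated through their UnmarshalXML methods, which is also where the defaults land; decoding a minimal <CSV> input element, for instance, fills in the "," field delimiter. A quick sketch (the surrounding example function is illustrative):

package example

import (
	"encoding/xml"
	"fmt"

	"github.com/minio/minio/pkg/s3select/csv"
)

func ExampleReaderArgs() {
	// RecordDelimiter and FieldDelimiter are omitted, so UnmarshalXML
	// substitutes the defaults "\n" and ",".
	blob := []byte(`<CSV><FileHeaderInfo>USE</FileHeaderInfo></CSV>`)

	var args csv.ReaderArgs
	if err := xml.Unmarshal(blob, &args); err != nil {
		panic(err)
	}
	fmt.Printf("%q %q %v\n", args.FileHeaderInfo, args.FieldDelimiter, args.IsEmpty())
	// Output: "use" "," false
}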

pkg/s3select/csv/errors.go (new file, 71 lines)

@@ -0,0 +1,71 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package csv
type s3Error struct {
code string
message string
statusCode int
cause error
}
func (err *s3Error) Cause() error {
return err.cause
}
func (err *s3Error) ErrorCode() string {
return err.code
}
func (err *s3Error) ErrorMessage() string {
return err.message
}
func (err *s3Error) HTTPStatusCode() int {
return err.statusCode
}
func (err *s3Error) Error() string {
return err.message
}
func errInvalidFileHeaderInfo(err error) *s3Error {
return &s3Error{
code: "InvalidFileHeaderInfo",
message: "The FileHeaderInfo is invalid. Only NONE, USE, and IGNORE are supported.",
statusCode: 400,
cause: err,
}
}
func errInvalidQuoteFields(err error) *s3Error {
return &s3Error{
code: "InvalidQuoteFields",
message: "The QuoteFields is invalid. Only ALWAYS and ASNEEDED are supported.",
statusCode: 400,
cause: err,
}
}
func errCSVParsingError(err error) *s3Error {
return &s3Error{
code: "CSVParsingError",
message: "Encountered an error parsing the CSV file. Check the file and try again.",
statusCode: 400,
cause: err,
}
}
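
The constructors are unexported, so these errors can only surface through the SelectError interface defined in the parent s3select package. A tiny in-package sketch of the shape they produce:

package csv

import "testing"

func TestErrorShape(t *testing.T) {
	err := errInvalidQuoteFields(nil)
	// code/message/status are what reach the client; Cause() keeps the
	// underlying parse error available for server-side logging.
	if err.ErrorCode() != "InvalidQuoteFields" || err.HTTPStatusCode() != 400 {
		t.Fatalf("unexpected error shape: %v", err)
	}
}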

pkg/s3select/csv/reader.go (new file, 166 lines)

@@ -0,0 +1,166 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package csv
import (
"bytes"
"encoding/csv"
"fmt"
"io"
"github.com/minio/minio/pkg/s3select/sql"
)
type recordReader struct {
reader io.Reader
recordDelimiter []byte
oneByte []byte
useOneByte bool
}
func (rr *recordReader) Read(p []byte) (n int, err error) {
if rr.useOneByte {
p[0] = rr.oneByte[0]
rr.useOneByte = false
n, err = rr.reader.Read(p[1:])
n++
} else {
n, err = rr.reader.Read(p)
}
if err != nil {
return 0, err
}
if string(rr.recordDelimiter) == "\n" {
return n, nil
}
for {
i := bytes.Index(p, rr.recordDelimiter)
if i < 0 {
break
}
p[i] = '\n'
if len(rr.recordDelimiter) > 1 {
p = append(p[:i+1], p[i+len(rr.recordDelimiter):]...)
}
}
n = len(p)
if len(rr.recordDelimiter) == 1 || p[n-1] != rr.recordDelimiter[0] {
return n, nil
}
if _, err = rr.reader.Read(rr.oneByte); err != nil {
return 0, err
}
if rr.oneByte[0] == rr.recordDelimiter[1] {
p[n-1] = '\n'
return n, nil
}
rr.useOneByte = true
return n, nil
}
// Reader - CSV record reader for S3Select.
type Reader struct {
args *ReaderArgs
readCloser io.ReadCloser
csvReader *csv.Reader
columnNames []string
}
// Read - reads single record.
func (r *Reader) Read() (sql.Record, error) {
csvRecord, err := r.csvReader.Read()
if err != nil {
if err != io.EOF {
return nil, errCSVParsingError(err)
}
return nil, err
}
columnNames := r.columnNames
if columnNames == nil {
columnNames = make([]string, len(csvRecord))
for i := range csvRecord {
columnNames[i] = fmt.Sprintf("_%v", i+1)
}
}
nameIndexMap := make(map[string]int64)
for i := range columnNames {
nameIndexMap[columnNames[i]] = int64(i)
}
return &Record{
columnNames: columnNames,
csvRecord: csvRecord,
nameIndexMap: nameIndexMap,
}, nil
}
// Close - closes underlying reader.
func (r *Reader) Close() error {
return r.readCloser.Close()
}
// NewReader - creates new CSV reader using readCloser.
func NewReader(readCloser io.ReadCloser, args *ReaderArgs) (*Reader, error) {
if args == nil || args.IsEmpty() {
panic(fmt.Errorf("empty args passed %v", args))
}
csvReader := csv.NewReader(&recordReader{
reader: readCloser,
recordDelimiter: []byte(args.RecordDelimiter),
oneByte: []byte{0},
})
csvReader.Comma = []rune(args.FieldDelimiter)[0]
csvReader.Comment = []rune(args.CommentCharacter)[0]
csvReader.FieldsPerRecord = -1
r := &Reader{
args: args,
readCloser: readCloser,
csvReader: csvReader,
}
if args.FileHeaderInfo == none {
return r, nil
}
record, err := csvReader.Read()
if err != nil {
if err != io.EOF {
return nil, errCSVParsingError(err)
}
return nil, err
}
if args.FileHeaderInfo == use {
r.columnNames = record
}
return r, nil
}
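
A usage sketch for the reader, written as an in-package test because ReaderArgs must pass through XML decoding before NewReader will accept it (NewReader panics on empty args):

package csv

import (
	"encoding/xml"
	"io"
	"io/ioutil"
	"strings"
	"testing"
)

func TestReaderSketch(t *testing.T) {
	var args ReaderArgs
	if err := xml.Unmarshal([]byte(`<CSV><FileHeaderInfo>USE</FileHeaderInfo></CSV>`), &args); err != nil {
		t.Fatal(err)
	}

	// With FileHeaderInfo USE, NewReader consumes the first row as column names.
	r, err := NewReader(ioutil.NopCloser(strings.NewReader("name,count\nminio,42\n")), &args)
	if err != nil {
		t.Fatal(err)
	}
	defer r.Close()

	for {
		record, err := r.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatal(err)
		}
		v, err := record.(*Record).Get("name") // column name came from the header row
		if err != nil {
			t.Fatal(err)
		}
		t.Log(v.CSVString())
	}
}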

pkg/s3select/csv/record.go (new file, 95 lines)

@@ -0,0 +1,95 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package csv
import (
"bytes"
"encoding/csv"
"fmt"
"github.com/minio/minio/pkg/s3select/sql"
"github.com/tidwall/sjson"
)
// Record - is CSV record.
type Record struct {
columnNames []string
csvRecord []string
nameIndexMap map[string]int64
}
// Get - gets the value for a column name.
func (r *Record) Get(name string) (*sql.Value, error) {
index, found := r.nameIndexMap[name]
if !found {
return nil, fmt.Errorf("column %v not found", name)
}
if index >= int64(len(r.csvRecord)) {
// No value found for column 'name', hence return empty string for compatibility.
return sql.NewString(""), nil
}
return sql.NewString(r.csvRecord[index]), nil
}
// Set - sets the value for a column name.
func (r *Record) Set(name string, value *sql.Value) error {
r.columnNames = append(r.columnNames, name)
r.csvRecord = append(r.csvRecord, value.CSVString())
return nil
}
// MarshalCSV - encodes to CSV data.
func (r *Record) MarshalCSV(fieldDelimiter rune) ([]byte, error) {
buf := new(bytes.Buffer)
w := csv.NewWriter(buf)
w.Comma = fieldDelimiter
if err := w.Write(r.csvRecord); err != nil {
return nil, err
}
w.Flush()
if err := w.Error(); err != nil {
return nil, err
}
data := buf.Bytes()
return data[:len(data)-1], nil
}
// MarshalJSON - encodes to JSON data.
func (r *Record) MarshalJSON() ([]byte, error) {
data := "{}"
var err error
for i := len(r.columnNames) - 1; i >= 0; i-- {
if i >= len(r.csvRecord) {
continue
}
if data, err = sjson.Set(data, r.columnNames[i], r.csvRecord[i]); err != nil {
return nil, err
}
}
return []byte(data), nil
}
// NewRecord - creates new CSV record.
func NewRecord() *Record {
return &Record{}
}
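
Record keeps column names and values side by side so the same row can be re-serialized as either CSV or JSON. A small in-package sketch; note that MarshalJSON walks the columns in reverse, so insertion order is preserved in the CSV output but not necessarily in the JSON object:

package csv

import (
	"testing"

	"github.com/minio/minio/pkg/s3select/sql"
)

func TestRecordSketch(t *testing.T) {
	rec := NewRecord()
	_ = rec.Set("name", sql.NewString("minio"))
	_ = rec.Set("count", sql.NewString("42"))

	csvData, err := rec.MarshalCSV(',')
	if err != nil {
		t.Fatal(err)
	}
	t.Log(string(csvData)) // minio,42

	jsonData, err := rec.MarshalJSON()
	if err != nil {
		t.Fatal(err)
	}
	t.Log(string(jsonData)) // e.g. {"count":"42","name":"minio"}
}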

pkg/s3select (deleted file)

@@ -1,110 +0,0 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3select
import (
"encoding/xml"
)
// CSVFileHeaderInfo -Can be either USE IGNORE OR NONE, defines what to do with
// the first row
type CSVFileHeaderInfo string
// Constants for file header info.
const (
CSVFileHeaderInfoNone CSVFileHeaderInfo = "NONE"
CSVFileHeaderInfoIgnore = "IGNORE"
CSVFileHeaderInfoUse = "USE"
)
// The maximum character per record is set to be 1 MB.
const (
MaxCharsPerRecord = 1000000
)
// SelectCompressionType - ONLY GZIP is supported
type SelectCompressionType string
// JSONType determines json input serialization type.
type JSONType string
// Constants for compression types under select API.
const (
SelectCompressionNONE SelectCompressionType = "NONE"
SelectCompressionGZIP = "GZIP"
SelectCompressionBZIP = "BZIP2"
)
// CSVQuoteFields - Can be either Always or AsNeeded
type CSVQuoteFields string
// Constants for csv quote styles.
const (
CSVQuoteFieldsAlways CSVQuoteFields = "Always"
CSVQuoteFieldsAsNeeded = "AsNeeded"
)
// QueryExpressionType - Currently can only be SQL
type QueryExpressionType string
// Constants for expression type.
const (
QueryExpressionTypeSQL QueryExpressionType = "SQL"
)
// Constants for JSONTypes.
const (
JSONTypeDocument JSONType = "DOCUMENT"
JSONLinesType = "LINES"
)
// ObjectSelectRequest - represents the input select body
type ObjectSelectRequest struct {
XMLName xml.Name `xml:"SelectObjectContentRequest" json:"-"`
Expression string
ExpressionType QueryExpressionType
InputSerialization struct {
CompressionType SelectCompressionType
Parquet *struct{}
CSV *struct {
FileHeaderInfo CSVFileHeaderInfo
RecordDelimiter string
FieldDelimiter string
QuoteCharacter string
QuoteEscapeCharacter string
Comments string
}
JSON *struct {
Type JSONType
}
}
OutputSerialization struct {
CSV *struct {
QuoteFields CSVQuoteFields
RecordDelimiter string
FieldDelimiter string
QuoteCharacter string
QuoteEscapeCharacter string
}
JSON *struct {
RecordDelimiter string
}
}
RequestProgress struct {
Enabled bool
}
}

pkg/s3select/errors.go

@@ -1,5 +1,5 @@
/*
- * Minio Cloud Storage, (C) 2018 Minio, Inc.
+ * Minio Cloud Storage, (C) 2019 Minio, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -16,456 +16,111 @@
package s3select

+// SelectError - represents s3 select error specified in
+// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html#RESTObjectSELECTContent-responses-special-errors.
+type SelectError interface {
+	Cause() error
+	ErrorCode() string
+	ErrorMessage() string
+	HTTPStatusCode() int
+	Error() string
+}
+
+type s3Error struct {
+	code       string
+	message    string
+	statusCode int
+	cause      error
+}
+
+func (err *s3Error) Cause() error {
+	return err.cause
+}
+
+func (err *s3Error) ErrorCode() string {
+	return err.code
+}
+
+func (err *s3Error) ErrorMessage() string {
+	return err.message
+}
+
+func (err *s3Error) HTTPStatusCode() int {
+	return err.statusCode
+}
+
+func (err *s3Error) Error() string {
+	return err.message
+}
+
+func errMalformedXML(err error) *s3Error {
+	return &s3Error{
+		code:       "MalformedXML",
+		message:    "The XML provided was not well-formed or did not validate against our published schema. Check the service documentation and try again.",
+		statusCode: 400,
+		cause:      err,
+	}
+}
+
+func errInvalidCompressionFormat(err error) *s3Error {
+	return &s3Error{
+		code:       "InvalidCompressionFormat",
+		message:    "The file is not in a supported compression format. Only GZIP and BZIP2 are supported.",
+		statusCode: 400,
+		cause:      err,
+	}
+}
+
+func errInvalidDataSource(err error) *s3Error {
+	return &s3Error{
+		code:       "InvalidDataSource",
+		message:    "Invalid data source type. Only CSV, JSON, and Parquet are supported.",
+		statusCode: 400,
+		cause:      err,
+	}
+}
+
+func errInvalidRequestParameter(err error) *s3Error {
+	return &s3Error{
+		code:       "InvalidRequestParameter",
+		message:    "The value of a parameter in SelectRequest element is invalid. Check the service API documentation and try again.",
+		statusCode: 400,
+		cause:      err,
+	}
+}
+
+func errObjectSerializationConflict(err error) *s3Error {
+	return &s3Error{
+		code:       "ObjectSerializationConflict",
+		message:    "InputSerialization specifies more than one format (CSV, JSON, or Parquet), or OutputSerialization specifies more than one format (CSV or JSON). InputSerialization and OutputSerialization can only specify one format each.",
+		statusCode: 400,
+		cause:      err,
+	}
+}
+
+func errInvalidExpressionType(err error) *s3Error {
+	return &s3Error{
+		code:       "InvalidExpressionType",
+		message:    "The ExpressionType is invalid. Only SQL expressions are supported.",
+		statusCode: 400,
+		cause:      err,
+	}
+}
+
+func errMissingRequiredParameter(err error) *s3Error {
+	return &s3Error{
+		code:       "MissingRequiredParameter",
+		message:    "The SelectRequest entity is missing a required parameter. Check the service documentation and try again.",
+		statusCode: 400,
+		cause:      err,
+	}
+}
+
+func errTruncatedInput(err error) *s3Error {
+	return &s3Error{
+		code:       "TruncatedInput",
+		message:    "Object decompression failed. Check that the object is properly compressed using the format specified in the request.",
+		statusCode: 400,
+		cause:      err,
+	}
+}

import (
	"errors"

	"github.com/minio/minio/pkg/s3select/format"
)

//S3 errors below

// ErrBusy is an error if the service is too busy.
var ErrBusy = errors.New("The service is unavailable. Please retry")

// ErrUnauthorizedAccess is an error if you lack the appropriate credentials to
// access the object.
var ErrUnauthorizedAccess = errors.New("You are not authorized to perform this operation")

// ErrExpressionTooLong is an error if your SQL expression is too long for
// processing.
var ErrExpressionTooLong = errors.New("The SQL expression is too long: The maximum byte-length for the SQL expression is 256 KB")

// ErrIllegalSQLFunctionArgument is an error if you provide an illegal argument
// in the SQL function.
var ErrIllegalSQLFunctionArgument = errors.New("Illegal argument was used in the SQL function")

// ErrInvalidKeyPath is an error if you provide a key in the SQL expression that
// is invalid.
var ErrInvalidKeyPath = errors.New("Key path in the SQL expression is invalid")

// ErrColumnTooLong is an error if your query results in a column that is
// greater than the max amount of characters per column of 1mb
var ErrColumnTooLong = errors.New("The length of a column in the result is greater than maxCharsPerColumn of 1 MB")

// ErrOverMaxColumn is an error if the number of columns from the resulting
// query is greater than 1Mb.
var ErrOverMaxColumn = errors.New("The number of columns in the result is greater than maxColumnNumber of 1 MB")

// ErrOverMaxRecordSize is an error if the length of a record in the result is
// greater than 1 Mb.
var ErrOverMaxRecordSize = errors.New("The length of a record in the result is greater than maxCharsPerRecord of 1 MB")

// ErrMissingHeaders is an error if some of the headers that are requested in
// the Select Query are not present in the file.
var ErrMissingHeaders = errors.New("Some headers in the query are missing from the file. Check the file and try again")

// ErrInvalidCompressionFormat is an error if an unsupported compression type is
// utilized with the select object query.
var ErrInvalidCompressionFormat = errors.New("The file is not in a supported compression format. Only GZIP is supported at this time")

// ErrInvalidFileHeaderInfo is an error if the argument provided to the
// FileHeader Argument is incorrect.
var ErrInvalidFileHeaderInfo = errors.New("The FileHeaderInfo is invalid. Only NONE, USE, and IGNORE are supported")

// ErrInvalidJSONType is an error if the json format provided as an argument is
// invalid.
var ErrInvalidJSONType = errors.New("The JsonType is invalid. Only DOCUMENT and LINES are supported at this time")

// ErrInvalidQuoteFields is an error if the arguments provided to the
// QuoteFields options are not valid.
var ErrInvalidQuoteFields = errors.New("The QuoteFields is invalid. Only ALWAYS and ASNEEDED are supported")

// ErrInvalidRequestParameter is an error if the value of a parameter in the
// request element is not valid.
var ErrInvalidRequestParameter = errors.New("The value of a parameter in Request element is invalid. Check the service API documentation and try again")

// ErrExternalEvalException is an error that arises if the query can not be
// evaluated.
var ErrExternalEvalException = errors.New("The query cannot be evaluated. Check the file and try again")

// ErrInvalidDataType is an error that occurs if the SQL expression contains an
// invalid data type.
var ErrInvalidDataType = errors.New("The SQL expression contains an invalid data type")

// ErrUnrecognizedFormatException is an error that arises if there is an invalid
// record type.
var ErrUnrecognizedFormatException = errors.New("Encountered an invalid record type")

// ErrInvalidTextEncoding is an error if the text encoding is not valid.
var ErrInvalidTextEncoding = errors.New("Invalid encoding type. Only UTF-8 encoding is supported at this time")

// ErrInvalidTableAlias is an error that arises if the table alias provided in
// the SQL expression is invalid.
var ErrInvalidTableAlias = errors.New("The SQL expression contains an invalid table alias")

// ErrMultipleDataSourcesUnsupported is an error that arises if multiple data
// sources are provided.
var ErrMultipleDataSourcesUnsupported = errors.New("Multiple data sources are not supported")

// ErrMissingRequiredParameter is an error that arises if a required argument
// is omitted from the Request.
var ErrMissingRequiredParameter = errors.New("The Request entity is missing a required parameter. Check the service documentation and try again")

// ErrObjectSerializationConflict is an error that arises if an unsupported
// output serialization is provided.
var ErrObjectSerializationConflict = errors.New("The Request entity can only contain one of CSV or JSON. Check the service documentation and try again")

// ErrUnsupportedSQLOperation is an error that arises if an unsupported SQL
// operation is used.
var ErrUnsupportedSQLOperation = errors.New("Encountered an unsupported SQL operation")

// ErrUnsupportedSQLStructure is an error that occurs if an unsupported SQL
// structure is used.
var ErrUnsupportedSQLStructure = errors.New("Encountered an unsupported SQL structure. Check the SQL Reference")

// ErrUnsupportedStorageClass is an error that occurs if an invalid storage
// class is present.
var ErrUnsupportedStorageClass = errors.New("Encountered an invalid storage class. Only STANDARD, STANDARD_IA, and ONEZONE_IA storage classes are supported at this time")

// ErrUnsupportedSyntax is an error that occurs if invalid syntax is present in
// the query.
var ErrUnsupportedSyntax = errors.New("Encountered invalid syntax")
// ErrUnsupportedRangeHeader is an error that occurs if a range header is
// provided.
var ErrUnsupportedRangeHeader = errors.New("Range header is not supported for this operation")
// ErrLexerInvalidChar is an error that occurs if the SQL expression contains an
// invalid character.
var ErrLexerInvalidChar = errors.New("The SQL expression contains an invalid character")
// ErrLexerInvalidOperator is an error that occurs if an invalid operator is
// used.
var ErrLexerInvalidOperator = errors.New("The SQL expression contains an invalid operator")
// ErrLexerInvalidLiteral is an error that occurs if an invalid literal is used.
var ErrLexerInvalidLiteral = errors.New("The SQL expression contains an invalid literal")
// ErrLexerInvalidIONLiteral is an error that occurs if an invalid operator is
// used
var ErrLexerInvalidIONLiteral = errors.New("The SQL expression contains an invalid operator")
// ErrParseExpectedDatePart is an error that occurs if the date part is not
// found in the SQL expression.
var ErrParseExpectedDatePart = errors.New("Did not find the expected date part in the SQL expression")
// ErrParseExpectedKeyword is an error that occurs if the expected keyword was
// not found in the expression.
var ErrParseExpectedKeyword = errors.New("Did not find the expected keyword in the SQL expression")
// ErrParseExpectedTokenType is an error that occurs if the expected token is
// not found in the SQL expression.
var ErrParseExpectedTokenType = errors.New("Did not find the expected token in the SQL expression")
// ErrParseExpected2TokenTypes is an error that occurs if 2 token types are not
// found.
var ErrParseExpected2TokenTypes = errors.New("Did not find the expected token in the SQL expression")
// ErrParseExpectedNumber is an error that occurs if a number is expected but
// not found in the expression.
var ErrParseExpectedNumber = errors.New("Did not find the expected number in the SQL expression")
// ErrParseExpectedRightParenBuiltinFunctionCall is an error that occurs if a
// right parenthesis is missing.
var ErrParseExpectedRightParenBuiltinFunctionCall = errors.New("Did not find the expected right parenthesis character in the SQL expression")
// ErrParseExpectedTypeName is an error that occurs if a type name is expected
// but not found.
var ErrParseExpectedTypeName = errors.New("Did not find the expected type name in the SQL expression")
// ErrParseExpectedWhenClause is an error that occurs if a When clause is
// expected but not found.
var ErrParseExpectedWhenClause = errors.New("Did not find the expected WHEN clause in the SQL expression. CASE is not supported")
// ErrParseUnsupportedToken is an error that occurs if the SQL expression
// contains an unsupported token.
var ErrParseUnsupportedToken = errors.New("The SQL expression contains an unsupported token")
// ErrParseUnsupportedLiteralsGroupBy is an error that occurs if the SQL
// expression has an unsupported use of Group By.
var ErrParseUnsupportedLiteralsGroupBy = errors.New("The SQL expression contains an unsupported use of GROUP BY")
// ErrParseExpectedMember is an error that occurs if there is an unsupported use
// of member in the SQL expression.
var ErrParseExpectedMember = errors.New("The SQL expression contains an unsupported use of MEMBER")
// ErrParseUnsupportedSelect is an error that occurs if there is an unsupported
// use of Select.
var ErrParseUnsupportedSelect = errors.New("The SQL expression contains an unsupported use of SELECT")
// ErrParseUnsupportedCase is an error that occurs if there is an unsupported
// use of case.
var ErrParseUnsupportedCase = errors.New("The SQL expression contains an unsupported use of CASE")
// ErrParseUnsupportedCaseClause is an error that occurs if there is an
// unsupported use of case.
var ErrParseUnsupportedCaseClause = errors.New("The SQL expression contains an unsupported use of CASE")
// ErrParseUnsupportedAlias is an error that occurs if there is an unsupported
// use of Alias.
var ErrParseUnsupportedAlias = errors.New("The SQL expression contains an unsupported use of ALIAS")
// ErrParseUnsupportedSyntax is an error that occurs if there is an
// UnsupportedSyntax in the SQL expression.
var ErrParseUnsupportedSyntax = errors.New("The SQL expression contains unsupported syntax")
// ErrParseUnknownOperator is an error that occurs if there is an invalid
// operator present in the SQL expression.
var ErrParseUnknownOperator = errors.New("The SQL expression contains an invalid operator")
// ErrParseMissingIdentAfterAt is an error that occurs if the wrong symbol
// follows the "@" symbol in the SQL expression.
var ErrParseMissingIdentAfterAt = errors.New("Did not find the expected identifier after the @ symbol in the SQL expression")
// ErrParseUnexpectedOperator is an error that occurs if the SQL expression
// contains an unexpected operator.
var ErrParseUnexpectedOperator = errors.New("The SQL expression contains an unexpected operator")
// ErrParseUnexpectedTerm is an error that occurs if the SQL expression contains
// an unexpected term.
var ErrParseUnexpectedTerm = errors.New("The SQL expression contains an unexpected term")
// ErrParseUnexpectedToken is an error that occurs if the SQL expression
// contains an unexpected token.
var ErrParseUnexpectedToken = errors.New("The SQL expression contains an unexpected token")
// ErrParseUnexpectedKeyword is an error that occurs if the SQL expression
// contains an unexpected keyword.
var ErrParseUnexpectedKeyword = errors.New("The SQL expression contains an unexpected keyword")
// ErrParseExpectedExpression is an error that occurs if the SQL expression is
// not found.
var ErrParseExpectedExpression = errors.New("Did not find the expected SQL expression")
// ErrParseExpectedLeftParenAfterCast is an error that occurs if the left
// parenthesis is missing after a cast in the SQL expression.
var ErrParseExpectedLeftParenAfterCast = errors.New("Did not find the expected left parenthesis after CAST in the SQL expression")
// ErrParseExpectedLeftParenValueConstructor is an error that occurs if the left
// parenthesis is not found in the SQL expression.
var ErrParseExpectedLeftParenValueConstructor = errors.New("Did not find expected the left parenthesis in the SQL expression")
// ErrParseExpectedLeftParenBuiltinFunctionCall is an error that occurs if the
// left parenthesis is not found in the SQL expression function call.
var ErrParseExpectedLeftParenBuiltinFunctionCall = errors.New("Did not find the expected left parenthesis in the SQL expression")
// ErrParseExpectedArgumentDelimiter is an error that occurs if the argument
// delimiter for the SQL expression is not provided.
var ErrParseExpectedArgumentDelimiter = errors.New("Did not find the expected argument delimiter in the SQL expression")
// ErrParseCastArity is an error that occurs because the CAST has incorrect
// arity.
var ErrParseCastArity = errors.New("The SQL expression CAST has incorrect arity")
// ErrParseInvalidTypeParam is an error that occurs because there is an invalid
// parameter value.
var ErrParseInvalidTypeParam = errors.New("The SQL expression contains an invalid parameter value")
// ErrParseEmptySelect is an error that occurs because the SQL expression
// contains an empty Select
var ErrParseEmptySelect = errors.New("The SQL expression contains an empty SELECT")
// ErrParseSelectMissingFrom is an error that occurs because there is a missing
// From after the Select List.
var ErrParseSelectMissingFrom = errors.New("The SQL expression contains a missing FROM after SELECT list")
// ErrParseExpectedIdentForGroupName is an error that occurs because Group is
// not supported in the SQL expression.
var ErrParseExpectedIdentForGroupName = errors.New("GROUP is not supported in the SQL expression")
// ErrParseExpectedIdentForAlias is an error that occurs if expected identifier
// for alias is not in the SQL expression.
var ErrParseExpectedIdentForAlias = errors.New("Did not find the expected identifier for the alias in the SQL expression")
// ErrParseUnsupportedCallWithStar is an error that occurs if COUNT is used with
// an argument other than "*".
var ErrParseUnsupportedCallWithStar = errors.New("Only COUNT with (*) as a parameter is supported in the SQL expression")
// ErrParseNonUnaryAgregateFunctionCall is an error that occurs if more than one
// argument is provided as an argument for aggregation functions.
var ErrParseNonUnaryAgregateFunctionCall = errors.New("Only one argument is supported for aggregate functions in the SQL expression")
// ErrParseMalformedJoin is an error that occurs if a "join" operation is
// attempted in the SQL expression as this is not supported.
var ErrParseMalformedJoin = errors.New("JOIN is not supported in the SQL expression")
// ErrParseExpectedIdentForAt is an error that occurs if after "AT" an Alias
// identifier is not provided.
var ErrParseExpectedIdentForAt = errors.New("Did not find the expected identifier for AT name in the SQL expression")
// ErrParseAsteriskIsNotAloneInSelectList is an error that occurs if in addition
// to an asterisk, more column names are provided as arguments in the SQL
// expression.
var ErrParseAsteriskIsNotAloneInSelectList = errors.New("Other expressions are not allowed in the SELECT list when '*' is used without dot notation in the SQL expression")
// ErrParseCannotMixSqbAndWildcardInSelectList is an error that occurs if list
// indexing and an asterisk are mixed in the SQL expression.
var ErrParseCannotMixSqbAndWildcardInSelectList = errors.New("Cannot mix [] and * in the same expression in a SELECT list in SQL expression")
// ErrParseInvalidContextForWildcardInSelectList is an error that occurs if the
// asterix is used improperly within the SQL expression.
var ErrParseInvalidContextForWildcardInSelectList = errors.New("Invalid use of * in SELECT list in the SQL expression")
// ErrEvaluatorBindingDoesNotExist is an error that occurs if a column name or
// path provided in the expression does not exist.
var ErrEvaluatorBindingDoesNotExist = errors.New("A column name or a path provided does not exist in the SQL expression")
// ErrIncorrectSQLFunctionArgumentType is an error that occurs if the wrong
// argument is provided to a SQL function.
var ErrIncorrectSQLFunctionArgumentType = errors.New("Incorrect type of arguments in function call in the SQL expression")
// ErrAmbiguousFieldName is an error that occurs if a column name, which is not
// case sensitive, is not descriptive enough to retrieve a single column.
var ErrAmbiguousFieldName = errors.New("Field name matches to multiple fields in the file. Check the SQL expression and the file, and try again")
// ErrEvaluatorInvalidArguments is an error that occurs if an incorrect number
// of arguments is passed in a function call in the SQL expression.
var ErrEvaluatorInvalidArguments = errors.New("Incorrect number of arguments in the function call in the SQL expression")
// ErrValueParseFailure is an error that occurs if the Time Stamp is not parsed
// correctly in the SQL expression.
var ErrValueParseFailure = errors.New("Time stamp parse failure in the SQL expression")
// ErrIntegerOverflow is an error that occurs if there is an integer overflow
// or underflow in the SQL expression.
var ErrIntegerOverflow = errors.New("Int overflow or underflow in the SQL expression")
// ErrLikeInvalidInputs is an error that occurs if invalid inputs are provided
// to the argument LIKE Clause.
var ErrLikeInvalidInputs = errors.New("Invalid argument given to the LIKE clause in the SQL expression")
// ErrCastFailed occurs if the attempt to convert data types in the cast is not
// done correctly.
var ErrCastFailed = errors.New("Attempt to convert from one data type to another using CAST failed in the SQL expression")
// ErrInvalidCast is an error that occurs if the attempt to convert between
// data types is improperly formed and fails.
var ErrInvalidCast = errors.New("Attempt to convert from one data type to another using CAST failed in the SQL expression")
// ErrEvaluatorInvalidTimestampFormatPattern is an error that occurs if the time
// stamp format requires additional fields to be filled.
var ErrEvaluatorInvalidTimestampFormatPattern = errors.New("Time stamp format pattern requires additional fields in the SQL expression")
// ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing is an error that
// occurs if the format of the time stamp can not be parsed.
var ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing = errors.New("Time stamp format pattern contains a valid format symbol that cannot be applied to time stamp parsing in the SQL expression")
// ErrEvaluatorTimestampFormatPatternDuplicateFields is an error that occurs if
// the time stamp format pattern contains multiple format specifications which
// can not be clearly resolved.
var ErrEvaluatorTimestampFormatPatternDuplicateFields = errors.New("Time stamp format pattern contains multiple format specifiers representing the time stamp field in the SQL expression")
// ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch is an error that
// occurs if the time stamp format pattern contains a 12-hour hour-of-day format
// symbol but does not have an AM/PM field.
var ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch = errors.New("Time stamp format pattern contains a 12-hour hour of day format symbol but doesn't also contain an AM/PM field, or it contains a 24-hour hour of day format specifier and contains an AM/PM field in the SQL expression")
// ErrEvaluatorUnterminatedTimestampFormatPatternToken is an error that occurs
// if there is an unterminated token in the SQL expression for time stamp
// format.
var ErrEvaluatorUnterminatedTimestampFormatPatternToken = errors.New("Time stamp format pattern contains unterminated token in the SQL expression")
// ErrEvaluatorInvalidTimestampFormatPatternToken is an error that occurs if
// there is an invalid token in the time stamp format within the SQL expression.
var ErrEvaluatorInvalidTimestampFormatPatternToken = errors.New("Time stamp format pattern contains an invalid token in the SQL expression")
// ErrEvaluatorInvalidTimestampFormatPatternSymbol is an error that occurs if
// the time stamp format pattern has an invalid symbol within the SQL
// expression.
var ErrEvaluatorInvalidTimestampFormatPatternSymbol = errors.New("Time stamp format pattern contains an invalid symbol in the SQL expression")
// S3 select API errors - TODO fix the errors.
var errorCodeResponse = map[error]string{
ErrBusy: "Busy",
ErrUnauthorizedAccess: "UnauthorizedAccess",
ErrExpressionTooLong: "ExpressionTooLong",
ErrIllegalSQLFunctionArgument: "IllegalSqlFunctionArgument",
format.ErrInvalidColumnIndex: "InvalidColumnIndex",
ErrInvalidKeyPath: "InvalidKeyPath",
ErrColumnTooLong: "ColumnTooLong",
ErrOverMaxColumn: "OverMaxColumn",
ErrOverMaxRecordSize: "OverMaxRecordSize",
ErrMissingHeaders: "MissingHeaders",
ErrInvalidCompressionFormat: "InvalidCompressionFormat",
format.ErrTruncatedInput: "TruncatedInput",
ErrInvalidFileHeaderInfo: "InvalidFileHeaderInfo",
ErrInvalidJSONType: "InvalidJsonType",
ErrInvalidQuoteFields: "InvalidQuoteFields",
ErrInvalidRequestParameter: "InvalidRequestParameter",
format.ErrCSVParsingError: "CSVParsingError",
format.ErrJSONParsingError: "JSONParsingError",
ErrExternalEvalException: "ExternalEvalException",
ErrInvalidDataType: "InvalidDataType",
ErrUnrecognizedFormatException: "UnrecognizedFormatException",
ErrInvalidTextEncoding: "InvalidTextEncoding",
ErrInvalidTableAlias: "InvalidTableAlias",
ErrMultipleDataSourcesUnsupported: "MultipleDataSourcesUnsupported",
ErrMissingRequiredParameter: "MissingRequiredParameter",
ErrObjectSerializationConflict: "ObjectSerializationConflict",
ErrUnsupportedSQLOperation: "UnsupportedSqlOperation",
ErrUnsupportedSQLStructure: "UnsupportedSqlStructure",
ErrUnsupportedStorageClass: "UnsupportedStorageClass",
ErrUnsupportedSyntax: "UnsupportedSyntax",
ErrUnsupportedRangeHeader: "UnsupportedRangeHeader",
ErrLexerInvalidChar: "LexerInvalidChar",
ErrLexerInvalidOperator: "LexerInvalidOperator",
ErrLexerInvalidLiteral: "LexerInvalidLiteral",
ErrLexerInvalidIONLiteral: "LexerInvalidIONLiteral",
ErrParseExpectedDatePart: "ParseExpectedDatePart",
ErrParseExpectedKeyword: "ParseExpectedKeyword",
ErrParseExpectedTokenType: "ParseExpectedTokenType",
ErrParseExpected2TokenTypes: "ParseExpected2TokenTypes",
ErrParseExpectedNumber: "ParseExpectedNumber",
ErrParseExpectedRightParenBuiltinFunctionCall: "ParseExpectedRightParenBuiltinFunctionCall",
ErrParseExpectedTypeName: "ParseExpectedTypeName",
ErrParseExpectedWhenClause: "ParseExpectedWhenClause",
ErrParseUnsupportedToken: "ParseUnsupportedToken",
ErrParseUnsupportedLiteralsGroupBy: "ParseUnsupportedLiteralsGroupBy",
ErrParseExpectedMember: "ParseExpectedMember",
ErrParseUnsupportedSelect: "ParseUnsupportedSelect",
ErrParseUnsupportedCase: "ParseUnsupportedCase",
ErrParseUnsupportedCaseClause: "ParseUnsupportedCaseClause",
ErrParseUnsupportedAlias: "ParseUnsupportedAlias",
ErrParseUnsupportedSyntax: "ParseUnsupportedSyntax",
ErrParseUnknownOperator: "ParseUnknownOperator",
format.ErrParseInvalidPathComponent: "ParseInvalidPathComponent",
ErrParseMissingIdentAfterAt: "ParseMissingIdentAfterAt",
ErrParseUnexpectedOperator: "ParseUnexpectedOperator",
ErrParseUnexpectedTerm: "ParseUnexpectedTerm",
ErrParseUnexpectedToken: "ParseUnexpectedToken",
ErrParseUnexpectedKeyword: "ParseUnexpectedKeyword",
ErrParseExpectedExpression: "ParseExpectedExpression",
ErrParseExpectedLeftParenAfterCast: "ParseExpectedLeftParenAfterCast",
ErrParseExpectedLeftParenValueConstructor: "ParseExpectedLeftParenValueConstructor",
ErrParseExpectedLeftParenBuiltinFunctionCall: "ParseExpectedLeftParenBuiltinFunctionCall",
ErrParseExpectedArgumentDelimiter: "ParseExpectedArgumentDelimiter",
ErrParseCastArity: "ParseCastArity",
ErrParseInvalidTypeParam: "ParseInvalidTypeParam",
ErrParseEmptySelect: "ParseEmptySelect",
ErrParseSelectMissingFrom: "ParseSelectMissingFrom",
ErrParseExpectedIdentForGroupName: "ParseExpectedIdentForGroupName",
ErrParseExpectedIdentForAlias: "ParseExpectedIdentForAlias",
ErrParseUnsupportedCallWithStar: "ParseUnsupportedCallWithStar",
ErrParseNonUnaryAgregateFunctionCall: "ParseNonUnaryAgregateFunctionCall",
ErrParseMalformedJoin: "ParseMalformedJoin",
ErrParseExpectedIdentForAt: "ParseExpectedIdentForAt",
ErrParseAsteriskIsNotAloneInSelectList: "ParseAsteriskIsNotAloneInSelectList",
ErrParseCannotMixSqbAndWildcardInSelectList: "ParseCannotMixSqbAndWildcardInSelectList",
ErrParseInvalidContextForWildcardInSelectList: "ParseInvalidContextForWildcardInSelectList",
ErrEvaluatorBindingDoesNotExist: "EvaluatorBindingDoesNotExist",
ErrIncorrectSQLFunctionArgumentType: "IncorrectSqlFunctionArgumentType",
ErrAmbiguousFieldName: "AmbiguousFieldName",
ErrEvaluatorInvalidArguments: "EvaluatorInvalidArguments",
ErrValueParseFailure: "ValueParseFailure",
ErrIntegerOverflow: "IntegerOverflow",
ErrLikeInvalidInputs: "LikeInvalidInputs",
ErrCastFailed: "CastFailed",
ErrInvalidCast: "InvalidCast",
ErrEvaluatorInvalidTimestampFormatPattern: "EvaluatorInvalidTimestampFormatPattern",
ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing: "EvaluatorInvalidTimestampFormatPatternSymbolForParsing",
ErrEvaluatorTimestampFormatPatternDuplicateFields: "EvaluatorTimestampFormatPatternDuplicateFields",
ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch: "EvaluatorTimestampFormatPatternHourClockAmPmMismatch",
ErrEvaluatorUnterminatedTimestampFormatPatternToken: "EvaluatorUnterminatedTimestampFormatPatternToken",
ErrEvaluatorInvalidTimestampFormatPatternToken: "EvaluatorInvalidTimestampFormatPatternToken",
ErrEvaluatorInvalidTimestampFormatPatternSymbol: "EvaluatorInvalidTimestampFormatPatternSymbol",
}
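// A minimal sketch of how this table might be consumed when composing an S3
// Select error response. selectErrCode is an illustrative helper, not part of
// the original file, and the "InternalError" fallback is an assumption.
func selectErrCode(err error) string {
	if code, ok := errorCodeResponse[err]; ok {
		return code
	}
	return "InternalError" // assumed fallback for errors missing from the table
}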

View File

@ -1,223 +0,0 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3select
import (
"strings"
"github.com/tidwall/gjson"
"github.com/xwb1989/sqlparser"
"github.com/minio/minio/pkg/s3select/format"
)
// stringOps is a function which handles the case in a clause
// where a string function needs to be applied
func stringOps(myFunc *sqlparser.FuncExpr, record []byte, myReturnVal string) string {
var value string
funcName := myFunc.Name.CompliantName()
switch tempArg := myFunc.Exprs[0].(type) {
case *sqlparser.AliasedExpr:
switch col := tempArg.Expr.(type) {
case *sqlparser.FuncExpr:
// myReturnVal is actually the tail recursive value being used in the eval func.
return applyStrFunc(gjson.Parse(myReturnVal), funcName)
case *sqlparser.ColName:
value = applyStrFunc(gjson.GetBytes(record, col.Name.CompliantName()), funcName)
case *sqlparser.SQLVal:
value = applyStrFunc(gjson.ParseBytes(col.Val), funcName)
}
}
return value
}
// coalOps is a function which decomposes a COALESCE func expr into its struct.
func coalOps(myFunc *sqlparser.FuncExpr, record []byte, myReturnVal string) string {
myArgs := make([]string, len(myFunc.Exprs))
for i, expr := range myFunc.Exprs {
switch tempArg := expr.(type) {
case *sqlparser.AliasedExpr:
switch col := tempArg.Expr.(type) {
case *sqlparser.FuncExpr:
// myReturnVal is actually the tail recursive value being used in the eval func.
return myReturnVal
case *sqlparser.ColName:
myArgs[i] = gjson.GetBytes(record, col.Name.CompliantName()).String()
case *sqlparser.SQLVal:
myArgs[i] = string(col.Val)
}
}
}
return processCoalNoIndex(myArgs)
}
// nullOps is a function which decomposes a NullIf func expr into its struct.
func nullOps(myFunc *sqlparser.FuncExpr, record []byte, myReturnVal string) string {
myArgs := make([]string, 2)
for i, expr := range myFunc.Exprs {
switch tempArg := expr.(type) {
case *sqlparser.AliasedExpr:
switch col := tempArg.Expr.(type) {
case *sqlparser.FuncExpr:
return myReturnVal
case *sqlparser.ColName:
myArgs[i] = gjson.GetBytes(record, col.Name.CompliantName()).String()
case *sqlparser.SQLVal:
myArgs[i] = string(col.Val)
}
}
}
if myArgs[0] == myArgs[1] {
return ""
}
return myArgs[0]
}
// isValidFunc is a function that ensures the
// current index is one with a StrFunc
func isValidFunc(myList []int, index int) bool {
if myList == nil {
return false
}
for _, i := range myList {
if i == index {
return true
}
}
return false
}
// processCoalNoIndex is a function which evaluates a given COALESCE clause.
func processCoalNoIndex(coalStore []string) string {
for _, coal := range coalStore {
if coal != "null" && coal != "missing" && coal != "" {
return coal
}
}
return "null"
}
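// Illustrative inputs and expected outputs for the COALESCE evaluation above
// (the values are hypothetical):
//
//   processCoalNoIndex([]string{"null", "", "7"})   // returns "7"
//   processCoalNoIndex([]string{"missing", "null"}) // returns "null"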
// evaluateFuncExpr is a function that allows for tail recursive evaluation of
// nested function expressions
func evaluateFuncExpr(myVal *sqlparser.FuncExpr, myReturnVal string, record []byte) string {
if myVal == nil {
return myReturnVal
}
// retrieve all the relevant arguments of the function
var mySubFunc []*sqlparser.FuncExpr
mySubFunc = make([]*sqlparser.FuncExpr, len(myVal.Exprs))
for i, expr := range myVal.Exprs {
switch col := expr.(type) {
case *sqlparser.AliasedExpr:
switch temp := col.Expr.(type) {
case *sqlparser.FuncExpr:
mySubFunc[i] = temp
}
}
}
// Need to do tree recursion so as to explore all possible directions of the
// nested function recursion
for i := 0; i < len(mySubFunc); i++ {
if supportedString(myVal.Name.CompliantName()) {
if mySubFunc != nil {
return stringOps(myVal, record, evaluateFuncExpr(mySubFunc[i], myReturnVal, record))
}
return stringOps(myVal, record, myReturnVal)
} else if strings.ToUpper(myVal.Name.CompliantName()) == "NULLIF" {
if mySubFunc != nil {
return nullOps(myVal, record, evaluateFuncExpr(mySubFunc[i], myReturnVal, record))
}
return nullOps(myVal, record, myReturnVal)
} else if strings.ToUpper(myVal.Name.CompliantName()) == "COALESCE" {
if mySubFunc != nil {
return coalOps(myVal, record, evaluateFuncExpr(mySubFunc[i], myReturnVal, record))
}
return coalOps(myVal, record, myReturnVal)
}
}
return ""
}
// evaluateFuncErr is a function that flags errors in nested functions.
func evaluateFuncErr(myVal *sqlparser.FuncExpr, reader format.Select) error {
if myVal == nil {
return nil
}
if !supportedFunc(myVal.Name.CompliantName()) {
return ErrUnsupportedSQLOperation
}
for _, expr := range myVal.Exprs {
switch tempArg := expr.(type) {
case *sqlparser.StarExpr:
return ErrParseUnsupportedCallWithStar
case *sqlparser.AliasedExpr:
switch col := tempArg.Expr.(type) {
case *sqlparser.FuncExpr:
if err := evaluateFuncErr(col, reader); err != nil {
return err
}
case *sqlparser.ColName:
if err := reader.ColNameErrs([]string{col.Name.CompliantName()}); err != nil {
return err
}
}
}
}
return nil
}
// evaluateIsExpr is a function for evaluating expressions of the form "column is ...."
func evaluateIsExpr(myFunc *sqlparser.IsExpr, row []byte, alias string) (bool, error) {
getMyVal := func() (myVal string) {
switch myIs := myFunc.Expr.(type) {
// case for literal val
case *sqlparser.SQLVal:
myVal = string(myIs.Val)
// case for nested func val
case *sqlparser.FuncExpr:
myVal = evaluateFuncExpr(myIs, "", row)
// case for col val
case *sqlparser.ColName:
myVal = gjson.GetBytes(row, myIs.Name.CompliantName()).String()
}
return myVal
}
operator := strings.ToLower(myFunc.Operator)
switch operator {
case "is null":
return getMyVal() == "", nil
case "is not null":
return getMyVal() != "", nil
default:
return false, ErrUnsupportedSQLOperation
}
}
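// A note on the IS NULL evaluation above: gjson renders a missing key, a JSON
// null, and an empty string all as "", so for illustrative records such as
// {"b":1} or {"a":""} the clause `a IS NULL` evaluates to true in every case.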
// supportedString is a function that checks whether the function is a supported
// string function
func supportedString(strFunc string) bool {
return format.StringInSlice(strings.ToUpper(strFunc), []string{"TRIM", "SUBSTRING", "CHAR_LENGTH", "CHARACTER_LENGTH", "LOWER", "UPPER"})
}
// supportedFunc is a function that checks whether the function is a supported
// S3 Select function.
func supportedFunc(strFunc string) bool {
return format.StringInSlice(strings.ToUpper(strFunc), []string{"TRIM", "SUBSTRING", "CHAR_LENGTH", "CHARACTER_LENGTH", "LOWER", "UPPER", "COALESCE", "NULLIF"})
}

View File

@ -1,339 +0,0 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package csv
import (
"encoding/csv"
"encoding/xml"
"io"
"strconv"
"strings"
"github.com/tidwall/sjson"
"github.com/minio/minio/pkg/ioutil"
"github.com/minio/minio/pkg/s3select/format"
)
// Options are passed to the underlying encoding/csv reader.
type Options struct {
// HasHeader when true, will treat the first row as a header row.
HasHeader bool
// RecordDelimiter is the string that records are delimited by.
RecordDelimiter string
// FieldDelimiter is the string that fields are delimited by.
FieldDelimiter string
// Comments is the string whose first character, when it begins
// a line of text, marks that line as a comment.
Comments string
// Name of the table that is used for querying
Name string
// ReadFrom is where the data will be read from.
ReadFrom io.Reader
// Compressed is the compression type; when set, a gzip or
// bzip2 reader is added to extract the CSV.
Compressed string
// SQL expression meant to be evaluated.
Expression string
// Output CSV will be delimited by.
OutputFieldDelimiter string
// Output CSV record will be delimited by.
OutputRecordDelimiter string
// Size of incoming object
StreamSize int64
// HeaderOpt records whether FileHeaderInfo is "USE"
HeaderOpt bool
// Progress enabled, enable/disable progress messages.
Progress bool
// Output format type, supported values are CSV and JSON
OutputType format.Type
}
// cinput represents a record producing input from a formatted object.
type cinput struct {
options *Options
reader *csv.Reader
firstRow []string
header []string
minOutputLength int
stats struct {
BytesScanned int64
BytesReturned int64
BytesProcessed int64
}
}
// New sets up a new Input; the first row is read when this is run.
// If there is a problem with reading the first row, the error is returned.
// Otherwise, the returned reader can be reliably consumed with Read()
// until Read() returns an error.
func New(opts *Options) (format.Select, error) {
// DelimitedReader treats custom record delimiter like `\r\n`,`\r`,`ab` etc and replaces it with `\n`.
normalizedReader := ioutil.NewDelimitedReader(opts.ReadFrom, []rune(opts.RecordDelimiter))
reader := &cinput{
options: opts,
reader: csv.NewReader(normalizedReader),
}
reader.stats.BytesScanned = opts.StreamSize
reader.stats.BytesProcessed = 0
reader.stats.BytesReturned = 0
reader.firstRow = nil
reader.reader.FieldsPerRecord = -1
if reader.options.FieldDelimiter != "" {
reader.reader.Comma = rune(reader.options.FieldDelimiter[0])
}
if reader.options.Comments != "" {
reader.reader.Comment = rune(reader.options.Comments[0])
}
// QuoteCharacter - " (defaulted currently)
reader.reader.LazyQuotes = true
if err := reader.readHeader(); err != nil {
return nil, err
}
return reader, nil
}
// Replace the spaces in column names with underscores
func cleanHeader(columns []string) []string {
for i := range columns {
// Even when a header is specified, some CSVs
// might have a mix of empty and non-empty
// column header names. In such a scenario we prepare
// an indexed value.
if columns[i] == "" {
columns[i] = "_" + strconv.Itoa(i)
}
columns[i] = strings.Replace(columns[i], " ", "_", -1)
}
return columns
}
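// For example, with an illustrative header row:
//
//   cleanHeader([]string{"first name", "", "age"})
//   // returns []string{"first_name", "_1", "age"}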
// readHeader reads the header into the header variable if the header is present
// as the first row of the csv
func (reader *cinput) readHeader() error {
var readErr error
if reader.options.HasHeader {
reader.firstRow, readErr = reader.reader.Read()
if readErr != nil {
return format.ErrCSVParsingError
}
reader.header = cleanHeader(reader.firstRow)
reader.firstRow = nil
} else {
reader.firstRow, readErr = reader.reader.Read()
if readErr != nil {
return format.ErrCSVParsingError
}
reader.header = make([]string, len(reader.firstRow))
for i := range reader.firstRow {
reader.header[i] = "_" + strconv.Itoa(i)
}
}
reader.minOutputLength = len(reader.header)
return nil
}
// Progress - return true if progress was requested.
func (reader *cinput) Progress() bool {
return reader.options.Progress
}
// UpdateBytesProcessed - updates the bytes processed
func (reader *cinput) UpdateBytesProcessed(size int64) {
reader.stats.BytesProcessed += size
}
// Read returns byte sequence
func (reader *cinput) Read() ([]byte, error) {
dec := reader.readRecord()
if dec != nil {
var data []byte
var err error
// Navigate column values in reverse order to preserve
// the input order for AWS S3 compatibility, because
// sjson adds JSON key/value pairs in first-in-last-out
// fashion. This should ideally be fixed in sjson. The
// following workaround circumvents this issue for now.
for i := len(dec) - 1; i >= 0; i-- {
data, err = sjson.SetBytes(data, reader.header[i], dec[i])
if err != nil {
return nil, err
}
}
return data, nil
}
return nil, nil
}
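// Illustration of the ordering workaround above, with hypothetical values:
// given header = []string{"name", "age"} and dec = []string{"milo", "7"},
// walking in reverse sets "age" first and "name" last, so the record comes
// out as {"name":"milo","age":"7"}, preserving the CSV column order.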
// OutputFieldDelimiter - returns the requested output field delimiter.
func (reader *cinput) OutputFieldDelimiter() string {
return reader.options.OutputFieldDelimiter
}
// OutputRecordDelimiter - returns the requested output record delimiter.
func (reader *cinput) OutputRecordDelimiter() string {
return reader.options.OutputRecordDelimiter
}
// HasHeader - returns true or false depending upon the header.
func (reader *cinput) HasHeader() bool {
return reader.options.HasHeader
}
// Expression - returns the Select expression
func (reader *cinput) Expression() string {
return reader.options.Expression
}
// UpdateBytesReturned - updates the bytes returned
func (reader *cinput) UpdateBytesReturned(size int64) {
reader.stats.BytesReturned += size
}
// Header returns the header of the reader. Either the first row if a header is
// set in the options, or _#, where # is the column number, starting with 0.
func (reader *cinput) Header() []string {
return reader.header
}
// readRecord reads a single record from the stream and always returns successfully.
// If the record is empty, an empty []string is returned.
// Records expand to match the current row size, adding blank fields as needed.
// Records never contain fewer than the number of fields in the first row.
// Returns nil on EOF.
// In the event of a parse error due to an invalid record, it is logged, and
// an empty []string is returned with the number of fields in the first row,
// as if the record were empty.
//
// In general, this reader is very tolerant of problems.
func (reader *cinput) readRecord() []string {
var row []string
var fileErr error
if reader.firstRow != nil {
row = reader.firstRow
reader.firstRow = nil
return row
}
row, fileErr = reader.reader.Read()
emptysToAppend := reader.minOutputLength - len(row)
if fileErr == io.EOF || fileErr == io.ErrClosedPipe {
return nil
} else if _, ok := fileErr.(*csv.ParseError); ok {
emptysToAppend = reader.minOutputLength
}
if emptysToAppend > 0 {
for counter := 0; counter < emptysToAppend; counter++ {
row = append(row, "")
}
}
return row
}
// CreateStatXML is the function which does the marshaling from the stat
// structs into XML so that the progress and stat message can be sent
func (reader *cinput) CreateStatXML() (string, error) {
if reader.options.Compressed == "NONE" {
reader.stats.BytesProcessed = reader.options.StreamSize
reader.stats.BytesScanned = reader.stats.BytesProcessed
}
out, err := xml.Marshal(&format.Stats{
BytesScanned: reader.stats.BytesScanned,
BytesProcessed: reader.stats.BytesProcessed,
BytesReturned: reader.stats.BytesReturned,
})
if err != nil {
return "", err
}
return xml.Header + string(out), nil
}
// CreateProgressXML is the function which does the marshaling from the progress
// structs into XML so that the progress and stat message can be sent
func (reader *cinput) CreateProgressXML() (string, error) {
if reader.options.HasHeader {
reader.stats.BytesProcessed += format.ProcessSize(reader.header)
}
if reader.options.Compressed == "NONE" {
reader.stats.BytesScanned = reader.stats.BytesProcessed
}
out, err := xml.Marshal(&format.Progress{
BytesScanned: reader.stats.BytesScanned,
BytesProcessed: reader.stats.BytesProcessed,
BytesReturned: reader.stats.BytesReturned,
})
if err != nil {
return "", err
}
return xml.Header + string(out), nil
}
// Type - return the data format type
func (reader *cinput) Type() format.Type {
return format.CSV
}
// OutputType - return the data format type
func (reader *cinput) OutputType() format.Type {
return reader.options.OutputType
}
// ColNameErrs is a function which makes sure that the requested headers are
// present in the file; otherwise it returns an error.
func (reader *cinput) ColNameErrs(columnNames []string) error {
for i := 0; i < len(columnNames); i++ {
if columnNames[i] == "" {
continue
}
if !format.IsInt(columnNames[i]) && !reader.options.HeaderOpt {
return format.ErrInvalidColumnIndex
}
if format.IsInt(columnNames[i]) {
tempInt, _ := strconv.Atoi(columnNames[i])
if tempInt > len(reader.Header()) || tempInt == 0 {
return format.ErrInvalidColumnIndex
}
} else {
if reader.options.HeaderOpt && !format.StringInSlice(columnNames[i], reader.Header()) {
return format.ErrParseInvalidPathComponent
}
}
}
return nil
}

View File

@ -1,38 +0,0 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package format
import "errors"
// ErrTruncatedInput is an error if the object is not compressed properly and an
// error occurs during decompression.
var ErrTruncatedInput = errors.New("Object decompression failed. Check that the object is properly compressed using the format specified in the request")
// ErrCSVParsingError is an error if the CSV presents an error while being
// parsed.
var ErrCSVParsingError = errors.New("Encountered an Error parsing the CSV file. Check the file and try again")
// ErrInvalidColumnIndex is an error if you provide a column index which is not
// valid.
var ErrInvalidColumnIndex = errors.New("Column index in the SQL expression is invalid")
// ErrParseInvalidPathComponent is an error that occurs if there is an invalid
// path component.
var ErrParseInvalidPathComponent = errors.New("The SQL expression contains an invalid path component")
// ErrJSONParsingError is an error if while parsing the JSON an error arises.
var ErrJSONParsingError = errors.New("Encountered an error parsing the JSON file. Check the file and try again")

View File

@ -1,50 +0,0 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package format
import "strconv"
// IsInt - returns true or false depending on whether a string
// can be represented as an int.
func IsInt(s string) bool {
_, err := strconv.Atoi(s)
return err == nil
}
// StringInSlice - this function finds whether a string is in a list
func StringInSlice(x string, list []string) bool {
for _, y := range list {
if x == y {
return true
}
}
return false
}
// ProcessSize - this function computes the byte size of a record so that we can calculate BytesProcessed.
func ProcessSize(myrecord []string) int64 {
if len(myrecord) > 0 {
// (len(myrecord)-1) field delimiters plus one record delimiter.
size := int64(len(myrecord)-1) + 1
for i := range myrecord {
size += int64(len(myrecord[i]))
}
return size
}
return 0
}
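// Worked example with an illustrative record: ProcessSize([]string{"ab", "c"})
// returns 5, i.e. 3 field bytes + 1 field delimiter + 1 record delimiter,
// matching the on-wire form "ab,c\n".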

View File

@ -1,205 +0,0 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package json
import (
"bufio"
"encoding/xml"
"io"
"github.com/minio/minio/pkg/s3select/format"
"github.com/tidwall/gjson"
)
// Options are passed to the underlying JSON reader.
type Options struct {
// Name of the table that is used for querying
Name string
// ReadFrom is where the data will be read from.
ReadFrom io.Reader
// Compressed is the compression type; when set, a gzip or
// bzip2 reader is added to extract the JSON.
Compressed string
// SQL expression meant to be evaluated.
Expression string
// Input record delimiter.
RecordDelimiter string
// Output CSV will be delimited by.
OutputFieldDelimiter string
// Output record delimiter.
OutputRecordDelimiter string
// Size of incoming object
StreamSize int64
// True if DocumentType is DOCUMENTS
DocumentType bool
// Progress enabled, enable/disable progress messages.
Progress bool
// Output format type, supported values are CSV and JSON
OutputType format.Type
}
// jinput represents a record producing input from a formatted file or pipe.
type jinput struct {
options *Options
reader *bufio.Reader
header []string
minOutputLength int
stats struct {
BytesScanned int64
BytesReturned int64
BytesProcessed int64
}
}
// New sets up a new reader; the first JSON record is read when this is run.
// If there is a problem with reading the first JSON record, the error is returned.
// Otherwise, the returned reader can be reliably consumed with Read()
// until Read() returns nil.
func New(opts *Options) (format.Select, error) {
reader := &jinput{
options: opts,
reader: bufio.NewReader(opts.ReadFrom),
}
reader.stats.BytesScanned = opts.StreamSize
reader.stats.BytesProcessed = 0
reader.stats.BytesReturned = 0
return reader, nil
}
// Progress - return true if progress was requested.
func (reader *jinput) Progress() bool {
return reader.options.Progress
}
// UpdateBytesProcessed - updates the bytes processed
func (reader *jinput) UpdateBytesProcessed(size int64) {
reader.stats.BytesProcessed += size
}
// Read reads a record from the file and returns it
func (reader *jinput) Read() ([]byte, error) {
data, _, err := reader.reader.ReadLine()
if err != nil {
if err == io.EOF || err == io.ErrClosedPipe {
err = nil
} else {
err = format.ErrJSONParsingError
}
}
if err == nil {
var header []string
gjson.ParseBytes(data).ForEach(func(key, value gjson.Result) bool {
header = append(header, key.String())
return true
})
reader.header = header
}
return data, err
}
// OutputFieldDelimiter - returns the delimiter specified in input request,
// for JSON output this value is empty, but does have a value when
// output type is CSV.
func (reader *jinput) OutputFieldDelimiter() string {
return reader.options.OutputFieldDelimiter
}
// OutputRecordDelimiter - returns the delimiter specified in input request, after each JSON record.
func (reader *jinput) OutputRecordDelimiter() string {
return reader.options.OutputRecordDelimiter
}
// HasHeader - returns true or false depending upon the header.
func (reader *jinput) HasHeader() bool {
return true
}
// Expression - returns the Select expression
func (reader *jinput) Expression() string {
return reader.options.Expression
}
// UpdateBytesReturned - updates the bytes returned
func (reader *jinput) UpdateBytesReturned(size int64) {
reader.stats.BytesReturned += size
}
// Header returns the keys of the most recently read JSON record
func (reader *jinput) Header() []string {
return reader.header
}
// CreateStatXML is the function which does the marshaling from the stat
// structs into XML so that the progress and stat message can be sent
func (reader *jinput) CreateStatXML() (string, error) {
if reader.options.Compressed == "NONE" {
reader.stats.BytesProcessed = reader.options.StreamSize
reader.stats.BytesScanned = reader.stats.BytesProcessed
}
out, err := xml.Marshal(&format.Stats{
BytesScanned: reader.stats.BytesScanned,
BytesProcessed: reader.stats.BytesProcessed,
BytesReturned: reader.stats.BytesReturned,
})
if err != nil {
return "", err
}
return xml.Header + string(out), nil
}
// CreateProgressXML is the function which does the marshaling from the progress
// structs into XML so that the progress and stat message can be sent
func (reader *jinput) CreateProgressXML() (string, error) {
if reader.options.Compressed == "NONE" {
reader.stats.BytesScanned = reader.stats.BytesProcessed
}
out, err := xml.Marshal(&format.Progress{
BytesScanned: reader.stats.BytesScanned,
BytesProcessed: reader.stats.BytesProcessed,
BytesReturned: reader.stats.BytesReturned,
})
if err != nil {
return "", err
}
return xml.Header + string(out), nil
}
// Type - return the data format type
func (reader *jinput) Type() format.Type {
return format.JSON
}
// OutputType - return the data format type
func (reader *jinput) OutputType() format.Type {
return reader.options.OutputType
}
// ColNameErrs - this is a dummy function for JSON input type.
func (reader *jinput) ColNameErrs(columnNames []string) error {
return nil
}

View File

@ -1,65 +0,0 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package format
import "encoding/xml"
// Select Interface helper methods, implementing features needed for
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
type Select interface {
Type() Type
OutputType() Type
Read() ([]byte, error)
Header() []string
HasHeader() bool
OutputFieldDelimiter() string
OutputRecordDelimiter() string
UpdateBytesProcessed(int64)
Expression() string
UpdateBytesReturned(int64)
CreateStatXML() (string, error)
CreateProgressXML() (string, error)
ColNameErrs(columnNames []string) error
Progress() bool
}
// Progress is a struct that represents the XML format of the
// progress messages
type Progress struct {
XMLName xml.Name `xml:"Progress" json:"-"`
BytesScanned int64 `xml:"BytesScanned"`
BytesProcessed int64 `xml:"BytesProcessed"`
BytesReturned int64 `xml:"BytesReturned"`
}
// Stats is a struct that represents the XML format of the stat
// messages
type Stats struct {
XMLName xml.Name `xml:"Stats" json:"-"`
BytesScanned int64 `xml:"BytesScanned"`
BytesProcessed int64 `xml:"BytesProcessed"`
BytesReturned int64 `xml:"BytesReturned"`
}
// Type denotes the supported data format types.
type Type string
// Different data format types.
const (
JSON Type = "json"
CSV Type = "csv"
)

182
pkg/s3select/genmessage.go Normal file
View File

@ -0,0 +1,182 @@
// +build ignore
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3select
import (
"bytes"
"encoding/binary"
"fmt"
"hash/crc32"
)
func genRecordsHeader() {
buf := new(bytes.Buffer)
buf.WriteByte(13)
buf.WriteString(":message-type")
buf.WriteByte(7)
buf.Write([]byte{0, 5})
buf.WriteString("event")
buf.WriteByte(13)
buf.WriteString(":content-type")
buf.WriteByte(7)
buf.Write([]byte{0, 24})
buf.WriteString("application/octet-stream")
buf.WriteByte(11)
buf.WriteString(":event-type")
buf.WriteByte(7)
buf.Write([]byte{0, 7})
buf.WriteString("Records")
fmt.Println(buf.Bytes())
}
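// Each header above follows the AWS event-stream encoding: a 1-byte header
// name length, the name itself, a value-type byte (7 == string), a 2-byte
// big-endian value length, and the value. A sketch that generalizes the
// hand-rolled sequence; writeStringHeader is our name, not part of this file.
func writeStringHeader(buf *bytes.Buffer, name, value string) {
	buf.WriteByte(byte(len(name)))                           // header-name length
	buf.WriteString(name)                                    // header name
	buf.WriteByte(7)                                         // value type 7 == string
	binary.Write(buf, binary.BigEndian, uint16(len(value))) // value length, e.g. {0, 5}
	buf.WriteString(value)                                   // header value
}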
// Continuation Message
// ====================
// Header specification
// --------------------
// Continuation messages contain two headers, as follows:
// https://docs.aws.amazon.com/AmazonS3/latest/API/images/s3select-frame-diagram-cont.png
//
// Payload specification
// ---------------------
// Continuation messages have no payload.
func genContinuationMessage() {
buf := new(bytes.Buffer)
buf.WriteByte(13)
buf.WriteString(":message-type")
buf.WriteByte(7)
buf.Write([]byte{0, 5})
buf.WriteString("event")
buf.WriteByte(11)
buf.WriteString(":event-type")
buf.WriteByte(7)
buf.Write([]byte{0, 4})
buf.WriteString("Cont")
header := buf.Bytes()
headerLength := len(header)
payloadLength := 0
totalLength := totalByteLength(headerLength, payloadLength)
buf = new(bytes.Buffer)
binary.Write(buf, binary.BigEndian, uint32(totalLength))
binary.Write(buf, binary.BigEndian, uint32(headerLength))
prelude := buf.Bytes()
binary.Write(buf, binary.BigEndian, crc32.ChecksumIEEE(prelude))
buf.Write(header)
message := buf.Bytes()
binary.Write(buf, binary.BigEndian, crc32.ChecksumIEEE(message))
fmt.Println(buf.Bytes())
}
func genProgressHeader() {
buf := new(bytes.Buffer)
buf.WriteByte(13)
buf.WriteString(":message-type")
buf.WriteByte(7)
buf.Write([]byte{0, 5})
buf.WriteString("event")
buf.WriteByte(13)
buf.WriteString(":content-type")
buf.WriteByte(7)
buf.Write([]byte{0, 8})
buf.WriteString("text/xml")
buf.WriteByte(11)
buf.WriteString(":event-type")
buf.WriteByte(7)
buf.Write([]byte{0, 8})
buf.WriteString("Progress")
fmt.Println(buf.Bytes())
}
func genStatsHeader() {
buf := new(bytes.Buffer)
buf.WriteByte(13)
buf.WriteString(":message-type")
buf.WriteByte(7)
buf.Write([]byte{0, 5})
buf.WriteString("event")
buf.WriteByte(13)
buf.WriteString(":content-type")
buf.WriteByte(7)
buf.Write([]byte{0, 8})
buf.WriteString("text/xml")
buf.WriteByte(11)
buf.WriteString(":event-type")
buf.WriteByte(7)
buf.Write([]byte{0, 5})
buf.WriteString("Stats")
fmt.Println(buf.Bytes())
}
// End Message
// ===========
// Header specification
// --------------------
// End messages contain two headers, as follows:
// https://docs.aws.amazon.com/AmazonS3/latest/API/images/s3select-frame-diagram-end.png
//
// Payload specification
// ---------------------
// End messages have no payload.
func genEndMessage() {
buf := new(bytes.Buffer)
buf.WriteByte(13)
buf.WriteString(":message-type")
buf.WriteByte(7)
buf.Write([]byte{0, 5})
buf.WriteString("event")
buf.WriteByte(11)
buf.WriteString(":event-type")
buf.WriteByte(7)
buf.Write([]byte{0, 3})
buf.WriteString("End")
header := buf.Bytes()
headerLength := len(header)
payloadLength := 0
totalLength := totalByteLength(headerLength, payloadLength)
buf = new(bytes.Buffer)
binary.Write(buf, binary.BigEndian, uint32(totalLength))
binary.Write(buf, binary.BigEndian, uint32(headerLength))
prelude := buf.Bytes()
binary.Write(buf, binary.BigEndian, crc32.ChecksumIEEE(prelude))
buf.Write(header)
message := buf.Bytes()
binary.Write(buf, binary.BigEndian, crc32.ChecksumIEEE(message))
fmt.Println(buf.Bytes())
}
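// totalByteLength is defined elsewhere in this file. Judging from the framing
// used above (a prelude of 4-byte total length and 4-byte header length, a
// 4-byte prelude CRC, the headers, the payload, and a trailing 4-byte message
// CRC), a sketch consistent with it, though not necessarily the file's own
// definition, would be:
//
//   func totalByteLength(headerLength, payloadLength int) int {
//       return 4 + 4 + 4 + headerLength + payloadLength + 4
//   }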

View File

@ -1,563 +0,0 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3select
import (
"fmt"
"math"
"strconv"
"strings"
"github.com/minio/minio/pkg/s3select/format"
"github.com/tidwall/gjson"
"github.com/xwb1989/sqlparser"
)
// MaxExpressionLength - 256KiB
const MaxExpressionLength = 256 * 1024
// matchesMyWhereClause takes []byte, processes the where clause and returns true if the row satisfies it
func matchesMyWhereClause(record []byte, alias string, whereClause sqlparser.Expr) (bool, error) {
var conversionColumn string
var operator string
var operand gjson.Result
if fmt.Sprintf("%v", whereClause) == "false" {
return false, nil
}
switch expr := whereClause.(type) {
case *sqlparser.IsExpr:
return evaluateIsExpr(expr, record, alias)
case *sqlparser.RangeCond:
operator = expr.Operator
if operator != "between" && operator != "not between" {
return false, ErrUnsupportedSQLOperation
}
result, err := evaluateBetween(expr, alias, record)
if err != nil {
return false, err
}
if operator == "not between" {
return !result, nil
}
return result, nil
case *sqlparser.ComparisonExpr:
operator = expr.Operator
switch right := expr.Right.(type) {
case *sqlparser.FuncExpr:
operand = gjson.Parse(evaluateFuncExpr(right, "", record))
case *sqlparser.SQLVal:
operand = gjson.ParseBytes(right.Val)
}
var myVal string
switch left := expr.Left.(type) {
case *sqlparser.FuncExpr:
myVal = evaluateFuncExpr(left, "", record)
conversionColumn = ""
case *sqlparser.ColName:
conversionColumn = left.Name.CompliantName()
}
if myVal != "" {
return evaluateOperator(gjson.Parse(myVal), operator, operand)
}
return evaluateOperator(gjson.GetBytes(record, conversionColumn), operator, operand)
case *sqlparser.AndExpr:
var leftVal bool
var rightVal bool
switch left := expr.Left.(type) {
case *sqlparser.ComparisonExpr:
temp, err := matchesMyWhereClause(record, alias, left)
if err != nil {
return false, err
}
leftVal = temp
}
switch right := expr.Right.(type) {
case *sqlparser.ComparisonExpr:
temp, err := matchesMyWhereClause(record, alias, right)
if err != nil {
return false, err
}
rightVal = temp
}
return (rightVal && leftVal), nil
case *sqlparser.OrExpr:
var leftVal bool
var rightVal bool
switch left := expr.Left.(type) {
case *sqlparser.ComparisonExpr:
leftVal, _ = matchesMyWhereClause(record, alias, left)
}
switch right := expr.Right.(type) {
case *sqlparser.ComparisonExpr:
rightVal, _ = matchesMyWhereClause(record, alias, right)
}
return (rightVal || leftVal), nil
}
return true, nil
}
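// A minimal usage sketch with illustrative values; the query and record are
// assumptions, not from the original file, and error handling is elided:
//
//   stmt, _ := sqlparser.Parse("select * from S3Object where age > 5")
//   sel := stmt.(*sqlparser.Select)
//   ok, _ := matchesMyWhereClause([]byte(`{"age":7}`), "S3Object", sel.Where.Expr)
//   // ok == true for this record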
func applyStrFunc(rawArg gjson.Result, funcName string) string {
switch strings.ToUpper(funcName) {
case "TRIM":
// the parser has an issue which does not allow it to support
// TRIM with other arguments
return strings.Trim(rawArg.String(), " ")
case "SUBSTRING":
// TODO: the parser has an issue which does not support SUBSTRING
return rawArg.String()
case "CHAR_LENGTH":
return strconv.Itoa(len(rawArg.String()))
case "CHARACTER_LENGTH":
return strconv.Itoa(len(rawArg.String()))
case "LOWER":
return strings.ToLower(rawArg.String())
case "UPPER":
return strings.ToUpper(rawArg.String())
}
return rawArg.String()
}
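// Illustrative inputs and outputs for the string functions above:
//
//   applyStrFunc(gjson.Parse(`"  milo  "`), "TRIM")       // "milo"
//   applyStrFunc(gjson.Parse(`"milo"`), "UPPER")          // "MILO"
//   applyStrFunc(gjson.Parse(`"milo"`), "CHAR_LENGTH")    // "4"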
// evaluateBetween is a function which evaluates a Between Clause.
func evaluateBetween(betweenExpr *sqlparser.RangeCond, alias string, record []byte) (bool, error) {
var colToVal gjson.Result
var colFromVal gjson.Result
var conversionColumn string
var funcName string
switch colTo := betweenExpr.To.(type) {
case sqlparser.Expr:
switch colToMyVal := colTo.(type) {
case *sqlparser.FuncExpr:
colToVal = gjson.Parse(stringOps(colToMyVal, record, ""))
case *sqlparser.SQLVal:
colToVal = gjson.ParseBytes(colToMyVal.Val)
}
}
switch colFrom := betweenExpr.From.(type) {
case sqlparser.Expr:
switch colFromMyVal := colFrom.(type) {
case *sqlparser.FuncExpr:
colFromVal = gjson.Parse(stringOps(colFromMyVal, record, ""))
case *sqlparser.SQLVal:
colFromVal = gjson.ParseBytes(colFromMyVal.Val)
}
}
var myFuncVal string
switch left := betweenExpr.Left.(type) {
case *sqlparser.FuncExpr:
myFuncVal = evaluateFuncExpr(left, "", record)
conversionColumn = ""
case *sqlparser.ColName:
conversionColumn = cleanCol(left.Name.CompliantName(), alias)
}
toGreater, err := evaluateOperator(colToVal, ">", colFromVal)
if err != nil {
return false, err
}
if toGreater {
return evalBetweenGreater(conversionColumn, record, funcName, colFromVal, colToVal, myFuncVal)
}
return evalBetweenLess(conversionColumn, record, funcName, colFromVal, colToVal, myFuncVal)
}
func evalBetween(conversionColumn string, record []byte, funcName string, colFromVal gjson.Result, colToVal gjson.Result, myColVal string, operator string) (bool, error) {
if format.IsInt(conversionColumn) {
myVal, err := evaluateOperator(gjson.GetBytes(record, "_"+conversionColumn), operator, colFromVal)
if err != nil {
return false, err
}
var myOtherVal bool
myOtherVal, err = evaluateOperator(colToVal, operator, gjson.GetBytes(record, "_"+conversionColumn))
if err != nil {
return false, err
}
return (myVal && myOtherVal), nil
}
if myColVal != "" {
myVal, err := evaluateOperator(gjson.Parse(myColVal), operator, colFromVal)
if err != nil {
return false, err
}
var myOtherVal bool
myOtherVal, err = evaluateOperator(colToVal, operator, gjson.Parse(myColVal))
if err != nil {
return false, err
}
return (myVal && myOtherVal), nil
}
myVal, err := evaluateOperator(gjson.GetBytes(record, conversionColumn), operator, colFromVal)
if err != nil {
return false, err
}
var myOtherVal bool
myOtherVal, err = evaluateOperator(colToVal, operator, gjson.GetBytes(record, conversionColumn))
if err != nil {
return false, err
}
return (myVal && myOtherVal), nil
}
// evalBetweenGreater is a function which evaluates the between given that the
// TO is greater than the FROM.
func evalBetweenGreater(conversionColumn string, record []byte, funcName string, colFromVal gjson.Result, colToVal gjson.Result, myColVal string) (bool, error) {
return evalBetween(conversionColumn, record, funcName, colFromVal, colToVal, myColVal, ">=")
}
// evalBetweenLess is a function which evaluates the between given that the
// FROM is greater than the TO.
func evalBetweenLess(conversionColumn string, record []byte, funcName string, colFromVal gjson.Result, colToVal gjson.Result, myColVal string) (bool, error) {
return evalBetween(conversionColumn, record, funcName, colFromVal, colToVal, myColVal, "<=")
}
// evaluateOperator is a really important function: it evaluates the boolean
// statement and actually returns a bool; it functions as the lowest
// level of the state machine.
func evaluateOperator(myTblVal gjson.Result, operator string, operand gjson.Result) (bool, error) {
if err := checkValidOperator(operator); err != nil {
return false, err
}
if !myTblVal.Exists() {
return false, nil
}
switch {
case operand.Type == gjson.String || operand.Type == gjson.Null:
return stringEval(myTblVal.String(), operator, operand.String())
case operand.Type == gjson.Number:
opInt := format.IsInt(operand.Raw)
tblValInt := format.IsInt(strings.Trim(myTblVal.Raw, "\""))
if opInt && tblValInt {
return intEval(int64(myTblVal.Float()), operator, operand.Int())
}
if !opInt && !tblValInt {
return floatEval(myTblVal.Float(), operator, operand.Float())
}
switch operator {
case "!=":
return true, nil
}
return false, nil
case myTblVal.Type != operand.Type:
return false, nil
default:
return false, ErrUnsupportedSyntax
}
}
// checkValidOperator ensures that the current operator is supported
func checkValidOperator(operator string) error {
listOfOps := []string{">", "<", "=", "<=", ">=", "!=", "like"}
for i := range listOfOps {
if operator == listOfOps[i] {
return nil
}
}
return ErrParseUnknownOperator
}
// stringEval is for evaluating the state of string comparison.
func stringEval(myRecordVal string, operator string, myOperand string) (bool, error) {
switch operator {
case ">":
return myRecordVal > myOperand, nil
case "<":
return myRecordVal < myOperand, nil
case "=":
return myRecordVal == myOperand, nil
case "<=":
return myRecordVal <= myOperand, nil
case ">=":
return myRecordVal >= myOperand, nil
case "!=":
return myRecordVal != myOperand, nil
case "like":
return likeConvert(myOperand, myRecordVal)
}
return false, ErrUnsupportedSyntax
}
// intEval is for evaluating integer comparisons.
func intEval(myRecordVal int64, operator string, myOperand int64) (bool, error) {
switch operator {
case ">":
return myRecordVal > myOperand, nil
case "<":
return myRecordVal < myOperand, nil
case "=":
return myRecordVal == myOperand, nil
case "<=":
return myRecordVal <= myOperand, nil
case ">=":
return myRecordVal >= myOperand, nil
case "!=":
return myRecordVal != myOperand, nil
}
return false, ErrUnsupportedSyntax
}
// floatEval is for evaluating the comparison of floats.
func floatEval(myRecordVal float64, operator string, myOperand float64) (bool, error) {
// If the types don't match, logic to check for a cast would be needed here
switch operator {
case ">":
return myRecordVal > myOperand, nil
case "<":
return myRecordVal < myOperand, nil
case "=":
return myRecordVal == myOperand, nil
case "<=":
return myRecordVal <= myOperand, nil
case ">=":
return myRecordVal >= myOperand, nil
case "!=":
return myRecordVal != myOperand, nil
}
return false, ErrUnsupportedSyntax
}
// prefixMatch allows for matching a prefix-only LIKE query, e.g. a%
func prefixMatch(pattern string, record string) bool {
for i := 0; i < len(pattern)-1; i++ {
if pattern[i] != record[i] && pattern[i] != byte('_') {
return false
}
}
return true
}
// suffixMatch allows for matching a suffix-only LIKE query, e.g. %an
func suffixMatch(pattern string, record string) bool {
for i := len(pattern) - 1; i > 0; i-- {
if pattern[i] != record[len(record)-(len(pattern)-i)] && pattern[i] != byte('_') {
return false
}
}
return true
}
// likeConvert evaluates LIKE clauses, which are case sensitive
func likeConvert(pattern string, record string) (bool, error) {
// If the pattern or the record is empty, just return false
if pattern == "" || record == "" {
return false, nil
}
// for suffix match queries e.g %a
if len(pattern) >= 2 && pattern[0] == byte('%') && strings.Count(pattern, "%") == 1 {
return suffixMatch(pattern, record), nil
}
// for prefix match queries e.g a%
if len(pattern) >= 2 && pattern[len(pattern)-1] == byte('%') && strings.Count(pattern, "%") == 1 {
return prefixMatch(pattern, record), nil
}
charCount := 0
currPos := 0
// Loop through the pattern so that a boolean can be returned
for i := 0; i < len(pattern); i++ {
if pattern[i] == byte('_') {
// if it's an underscore it can be anything, so shift the current position for
// the pattern and the string
charCount++
// if there have been more characters in the pattern than in the record,
// clearly there should be a return
if i != len(pattern)-1 {
if pattern[i+1] != byte('%') && pattern[i+1] != byte('_') {
if currPos != len(record)-1 && pattern[i+1] != record[currPos+1] {
return false, nil
}
}
}
if charCount > len(record) {
return false, nil
}
// if the pattern has been fully evaluated, then just return.
if len(pattern) == i+1 {
return true, nil
}
i++
currPos++
}
if pattern[i] == byte('%') || pattern[i] == byte('*') {
// if there is a wildcard then return true if it is last, and flag it.
if currPos == len(record) {
return false, nil
}
if i+1 == len(pattern) {
return true, nil
}
} else {
charCount++
matched := false
// iterate through the pattern and check if there is a match for the
// character
for currPos < len(record) {
if record[currPos] == pattern[i] || pattern[i] == byte('_') {
matched = true
break
}
currPos++
}
currPos++
// if the character did not match then a return should occur.
if !matched {
return false, nil
}
}
}
if charCount > len(record) {
return false, nil
}
if currPos < len(record) {
return false, nil
}
return true, nil
}
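// Worked examples with hypothetical patterns and records:
//
//   likeConvert("mi%", "milo")  // true: prefix match
//   likeConvert("%lo", "milo")  // true: suffix match
//   likeConvert("m_lo", "milo") // true: '_' matches any single character
//   likeConvert("mi%", "ben")   // false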
// cleanCol cleans a column name from the parser so that the name is returned
// to its original form.
func cleanCol(myCol string, alias string) string {
if len(myCol) <= 0 {
return myCol
}
if !strings.HasPrefix(myCol, alias) && myCol[0] == '_' {
myCol = alias + myCol
}
if strings.Contains(myCol, ".") {
myCol = strings.Replace(myCol, alias+"._", "", len(myCol))
}
myCol = strings.Replace(myCol, alias+"_", "", len(myCol))
return myCol
}
// whereClauseNameErrs is a function which returns an error if there is a column
// in the where clause which does not exist.
func whereClauseNameErrs(whereClause interface{}, alias string, f format.Select) error {
var conversionColumn string
switch expr := whereClause.(type) {
// case for checking errors within a clause of the form "col_name is ..."
case *sqlparser.IsExpr:
switch myCol := expr.Expr.(type) {
case *sqlparser.FuncExpr:
if err := evaluateFuncErr(myCol, f); err != nil {
return err
}
case *sqlparser.ColName:
conversionColumn = cleanCol(myCol.Name.CompliantName(), alias)
}
case *sqlparser.RangeCond:
switch left := expr.Left.(type) {
case *sqlparser.FuncExpr:
if err := evaluateFuncErr(left, f); err != nil {
return err
}
case *sqlparser.ColName:
conversionColumn = cleanCol(left.Name.CompliantName(), alias)
}
case *sqlparser.ComparisonExpr:
switch left := expr.Left.(type) {
case *sqlparser.FuncExpr:
if err := evaluateFuncErr(left, f); err != nil {
return err
}
case *sqlparser.ColName:
conversionColumn = cleanCol(left.Name.CompliantName(), alias)
}
case *sqlparser.AndExpr:
switch left := expr.Left.(type) {
case *sqlparser.ComparisonExpr:
return whereClauseNameErrs(left, alias, f)
}
switch right := expr.Right.(type) {
case *sqlparser.ComparisonExpr:
return whereClauseNameErrs(right, alias, f)
}
case *sqlparser.OrExpr:
switch left := expr.Left.(type) {
case *sqlparser.ComparisonExpr:
return whereClauseNameErrs(left, alias, f)
}
switch right := expr.Right.(type) {
case *sqlparser.ComparisonExpr:
return whereClauseNameErrs(right, alias, f)
}
}
if conversionColumn != "" {
return f.ColNameErrs([]string{conversionColumn})
}
return nil
}
// aggFuncToStr converts an array of floats into a properly formatted string.
func aggFuncToStr(aggVals []float64, f format.Select) string {
// Define a number formatting function
numToStr := func(f float64) string {
if f == math.Trunc(f) {
return strconv.FormatInt(int64(f), 10)
}
return strconv.FormatFloat(f, 'f', 6, 64)
}
// Display all whole numbers in aggVals as integers
vals := make([]string, len(aggVals))
for i, v := range aggVals {
vals[i] = numToStr(v)
}
// Intersperse field delimiter
return strings.Join(vals, f.OutputFieldDelimiter())
}
// checkForDuplicates ensures we do not have an ambiguous column name.
func checkForDuplicates(columns []string, columnsMap map[string]int) error {
for i, column := range columns {
columns[i] = strings.Replace(column, " ", "_", len(column))
if _, exist := columnsMap[columns[i]]; exist {
return ErrAmbiguousFieldName
}
columnsMap[columns[i]] = i
}
return nil
}
// parseErrs is the function which handles all the errors that could occur
// through use of function arguments such as column names in NULLIF
func parseErrs(columnNames []string, whereClause interface{}, alias string, myFuncs SelectFuncs, f format.Select) error {
// Below code cleans up column names.
processColumnNames(columnNames, alias, f)
if columnNames[0] != "*" {
if err := f.ColNameErrs(columnNames); err != nil {
return err
}
}
// Below code ensures the whereClause has no errors.
if whereClause != nil {
tempClause := whereClause
if err := whereClauseNameErrs(tempClause, alias, f); err != nil {
return err
}
}
for i := 0; i < len(myFuncs.funcExpr); i++ {
if myFuncs.funcExpr[i] == nil {
continue
}
if err := evaluateFuncErr(myFuncs.funcExpr[i], f); err != nil {
return err
}
}
return nil
}

View File

@ -1,224 +0,0 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3select
import (
"bytes"
"compress/bzip2"
"io"
"net/http"
"strings"
"time"
humanize "github.com/dustin/go-humanize"
"github.com/klauspost/pgzip"
"github.com/minio/minio/pkg/s3select/format"
"github.com/minio/minio/pkg/s3select/format/csv"
"github.com/minio/minio/pkg/s3select/format/json"
)
const (
// progressTime is the time interval for which a progress message is sent.
progressTime time.Duration = 60 * time.Second
// continuationTime is the time interval for which a continuation message is
// sent.
continuationTime time.Duration = 5 * time.Second
)
// Row is a struct for keeping track of the key aspects of a row.
type Row struct {
record string
err error
}
// cleanExpr replaces double quotes (") with backticks (`) for the select parser
func cleanExpr(expr string) string {
r := strings.NewReplacer("\"", "`")
return r.Replace(expr)
}
// New - initializes a new select format
func New(reader io.Reader, size int64, req ObjectSelectRequest) (s3s format.Select, err error) {
switch req.InputSerialization.CompressionType {
case SelectCompressionGZIP:
if reader, err = pgzip.NewReader(reader); err != nil {
return nil, format.ErrTruncatedInput
}
case SelectCompressionBZIP:
reader = bzip2.NewReader(reader)
}
// Initializing options for CSV
if req.InputSerialization.CSV != nil {
if req.InputSerialization.CSV.FileHeaderInfo == "" {
req.InputSerialization.CSV.FileHeaderInfo = CSVFileHeaderInfoNone
}
if req.InputSerialization.CSV.RecordDelimiter == "" {
req.InputSerialization.CSV.RecordDelimiter = "\n"
}
options := &csv.Options{
Name: "S3Object", // Default table name for all objects
HasHeader: req.InputSerialization.CSV.FileHeaderInfo == CSVFileHeaderInfoUse,
RecordDelimiter: req.InputSerialization.CSV.RecordDelimiter,
FieldDelimiter: req.InputSerialization.CSV.FieldDelimiter,
Comments: req.InputSerialization.CSV.Comments,
ReadFrom: reader,
Compressed: string(req.InputSerialization.CompressionType),
Expression: cleanExpr(req.Expression),
StreamSize: size,
HeaderOpt: req.InputSerialization.CSV.FileHeaderInfo == CSVFileHeaderInfoUse,
Progress: req.RequestProgress.Enabled,
}
if req.OutputSerialization.CSV != nil {
if req.OutputSerialization.CSV.FieldDelimiter == "" {
req.OutputSerialization.CSV.FieldDelimiter = ","
}
options.OutputFieldDelimiter = req.OutputSerialization.CSV.FieldDelimiter
options.OutputRecordDelimiter = req.OutputSerialization.CSV.RecordDelimiter
options.OutputType = format.CSV
}
if req.OutputSerialization.JSON != nil {
options.OutputRecordDelimiter = req.OutputSerialization.JSON.RecordDelimiter
options.OutputType = format.JSON
}
// Initialize CSV input type
s3s, err = csv.New(options)
} else if req.InputSerialization.JSON != nil {
options := &json.Options{
Name: "S3Object", // Default table name for all objects
ReadFrom: reader,
Compressed: string(req.InputSerialization.CompressionType),
Expression: cleanExpr(req.Expression),
StreamSize: size,
DocumentType: req.InputSerialization.JSON.Type == JSONTypeDocument,
Progress: req.RequestProgress.Enabled,
}
if req.OutputSerialization.JSON != nil {
options.OutputRecordDelimiter = req.OutputSerialization.JSON.RecordDelimiter
options.OutputType = format.JSON
}
if req.OutputSerialization.CSV != nil {
options.OutputFieldDelimiter = req.OutputSerialization.CSV.FieldDelimiter
options.OutputRecordDelimiter = req.OutputSerialization.CSV.RecordDelimiter
options.OutputType = format.CSV
}
// Initialize JSON input type
s3s, err = json.New(options)
}
return s3s, err
}
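A minimal sketch of how the two entry points fit together; runQuery is an illustrative name, and the ObjectSelectRequest value is assumed to be already parsed from the request body:

func runQuery(reader io.Reader, size int64, req ObjectSelectRequest, w io.Writer) error {
	// New picks the CSV or JSON implementation based on the request and
	// wraps the object stream with any requested decompression.
	s3s, err := New(reader, size, req)
	if err != nil {
		return err
	}
	// Execute streams records, stats and keep-alive messages to w.
	return Execute(w, s3s)
}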
// Execute is the function where all the blocking occurs. It writes to the HTTP
// response writer in a streaming fashion so that the client can actively use
// the results before the query has finished executing.
func Execute(writer io.Writer, f format.Select) error {
rowCh := make(chan Row)
curBuf := bytes.NewBuffer(make([]byte, humanize.MiByte))
curBuf.Reset()
progressTicker := time.NewTicker(progressTime)
continuationTimer := time.NewTimer(continuationTime)
defer progressTicker.Stop()
defer continuationTimer.Stop()
go runSelectParser(f, rowCh)
for {
select {
case row, ok := <-rowCh:
if ok && row.err != nil {
_, err := writeErrorMessage(row.err, curBuf).WriteTo(writer)
flusher, okFlush := writer.(http.Flusher)
if okFlush {
flusher.Flush()
}
if err != nil {
return err
}
curBuf.Reset()
close(rowCh)
return nil
} else if ok {
_, err := writeRecordMessage(row.record, curBuf).WriteTo(writer)
flusher, okFlush := writer.(http.Flusher)
if okFlush {
flusher.Flush()
}
if err != nil {
return err
}
curBuf.Reset()
f.UpdateBytesReturned(int64(len(row.record)))
if !continuationTimer.Stop() {
<-continuationTimer.C
}
continuationTimer.Reset(continuationTime)
} else if !ok {
statPayload, err := f.CreateStatXML()
if err != nil {
return err
}
_, err = writeStatMessage(statPayload, curBuf).WriteTo(writer)
flusher, ok := writer.(http.Flusher)
if ok {
flusher.Flush()
}
if err != nil {
return err
}
curBuf.Reset()
_, err = writeEndMessage(curBuf).WriteTo(writer)
flusher, ok = writer.(http.Flusher)
if ok {
flusher.Flush()
}
if err != nil {
return err
}
return nil
}
case <-progressTicker.C:
// Send progress messages only if requested by client.
if f.Progress() {
progressPayload, err := f.CreateProgressXML()
if err != nil {
return err
}
_, err = writeProgressMessage(progressPayload, curBuf).WriteTo(writer)
flusher, ok := writer.(http.Flusher)
if ok {
flusher.Flush()
}
if err != nil {
return err
}
curBuf.Reset()
}
case <-continuationTimer.C:
_, err := writeContinuationMessage(curBuf).WriteTo(writer)
flusher, ok := writer.(http.Flusher)
if ok {
flusher.Flush()
}
if err != nil {
return err
}
curBuf.Reset()
continuationTimer.Reset(continuationTime)
}
}
}

pkg/s3select/json/args.go Normal file

@ -0,0 +1,95 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package json
import (
"encoding/xml"
"fmt"
"strings"
)
const (
document = "document"
lines = "lines"
defaultRecordDelimiter = "\n"
)
// ReaderArgs - represents elements inside <InputSerialization><JSON/> in request XML.
type ReaderArgs struct {
ContentType string `xml:"Type"`
unmarshaled bool
}
// IsEmpty - returns whether reader args is empty or not.
func (args *ReaderArgs) IsEmpty() bool {
return !args.unmarshaled
}
// UnmarshalXML - decodes XML data.
func (args *ReaderArgs) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
// Make subtype to avoid recursive UnmarshalXML().
type subReaderArgs ReaderArgs
parsedArgs := subReaderArgs{}
if err := d.DecodeElement(&parsedArgs, &start); err != nil {
return err
}
parsedArgs.ContentType = strings.ToLower(parsedArgs.ContentType)
switch parsedArgs.ContentType {
case document, lines:
default:
return errInvalidJSONType(fmt.Errorf("invalid ContentType '%v'", parsedArgs.ContentType))
}
*args = ReaderArgs(parsedArgs)
args.unmarshaled = true
return nil
}
// WriterArgs - represents elements inside <OutputSerialization><JSON/> in request XML.
type WriterArgs struct {
RecordDelimiter string `xml:"RecordDelimiter"`
unmarshaled bool
}
// IsEmpty - returns whether writer args is empty or not.
func (args *WriterArgs) IsEmpty() bool {
return !args.unmarshaled
}
// UnmarshalXML - decodes XML data.
func (args *WriterArgs) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
// Make subtype to avoid recursive UnmarshalXML().
type subWriterArgs WriterArgs
parsedArgs := subWriterArgs{}
if err := d.DecodeElement(&parsedArgs, &start); err != nil {
return err
}
switch len(parsedArgs.RecordDelimiter) {
case 0:
parsedArgs.RecordDelimiter = defaultRecordDelimiter
case 1, 2:
default:
return fmt.Errorf("invalid RecordDelimiter '%v'", parsedArgs.RecordDelimiter)
}
*args = WriterArgs(parsedArgs)
args.unmarshaled = true
return nil
}


@ -0,0 +1,62 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package json
type s3Error struct {
code string
message string
statusCode int
cause error
}
func (err *s3Error) Cause() error {
return err.cause
}
func (err *s3Error) ErrorCode() string {
return err.code
}
func (err *s3Error) ErrorMessage() string {
return err.message
}
func (err *s3Error) HTTPStatusCode() int {
return err.statusCode
}
func (err *s3Error) Error() string {
return err.message
}
func errInvalidJSONType(err error) *s3Error {
return &s3Error{
code: "InvalidJsonType",
message: "The JsonType is invalid. Only DOCUMENT and LINES are supported.",
statusCode: 400,
cause: err,
}
}
func errJSONParsingError(err error) *s3Error {
return &s3Error{
code: "JSONParsingError",
message: "Encountered an error parsing the JSON file. Check the file and try again.",
statusCode: 400,
cause: err,
}
}

pkg/s3select/json/reader.go Normal file

@ -0,0 +1,217 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package json
import (
"bytes"
"io"
"io/ioutil"
"strconv"
"github.com/minio/minio/pkg/s3select/sql"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
func toSingleLineJSON(input string, currentKey string, result gjson.Result) (output string, err error) {
switch {
case result.IsObject():
result.ForEach(func(key, value gjson.Result) bool {
jsonKey := key.String()
if currentKey != "" {
jsonKey = currentKey + "." + key.String()
}
output, err = toSingleLineJSON(input, jsonKey, value)
input = output
return err == nil
})
case result.IsArray():
i := 0
result.ForEach(func(key, value gjson.Result) bool {
if currentKey == "" {
panic("currentKey is empty")
}
indexKey := currentKey + "." + strconv.Itoa(i)
output, err = toSingleLineJSON(input, indexKey, value)
input = output
i++
return err == nil
})
default:
output, err = sjson.Set(input, currentKey, result.Value())
}
return output, err
}
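To illustrate the walk above, a sketch that runs inside this package (assuming fmt and gjson are imported): each leaf value is re-set on an initially empty document through sjson, so a pretty-printed object is rebuilt as its compact single-line equivalent. Note that the dotted-path scheme assumes object keys themselves contain no dots.

func exampleToSingleLineJSON() {
	pretty := `{
	"title": "S3 Select",
	"tags": ["csv", "json"]
}`
	out, err := toSingleLineJSON("", "", gjson.Parse(pretty))
	if err != nil {
		panic(err)
	}
	// Prints: {"title":"S3 Select","tags":["csv","json"]}
	fmt.Println(out)
}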
type objectReader struct {
reader io.Reader
err error
p []byte
start int
end int
escaped bool
quoteOpened bool
curlyCount uint64
endOfObject bool
}
func (or *objectReader) objectEndIndex(p []byte, length int) int {
for i := 0; i < length; i++ {
if p[i] == '\\' {
or.escaped = !or.escaped
continue
}
if p[i] == '"' && !or.escaped {
or.quoteOpened = !or.quoteOpened
}
or.escaped = false
switch p[i] {
case '{':
if !or.quoteOpened {
or.curlyCount++
}
case '}':
if or.quoteOpened || or.curlyCount == 0 {
break
}
if or.curlyCount--; or.curlyCount == 0 {
return i + 1
}
}
}
return -1
}
func (or *objectReader) Read(p []byte) (n int, err error) {
if or.endOfObject {
return 0, io.EOF
}
if or.p != nil {
n = copy(p, or.p[or.start:or.end])
or.start += n
if or.start == or.end {
// made full copy.
or.p = nil
or.start = 0
or.end = 0
}
} else {
if or.err != nil {
return 0, or.err
}
n, err = or.reader.Read(p)
or.err = err
switch err {
case nil:
case io.EOF, io.ErrUnexpectedEOF, io.ErrClosedPipe:
or.err = io.EOF
default:
return 0, err
}
}
index := or.objectEndIndex(p, n)
if index == -1 || index == n {
return n, nil
}
or.endOfObject = true
if or.p == nil {
or.p = p
or.start = index
or.end = n
} else {
or.start -= index
}
return index, nil
}
func (or *objectReader) Reset() error {
or.endOfObject = false
if or.p != nil {
return nil
}
return or.err
}
// Reader - JSON record reader for S3Select.
type Reader struct {
args *ReaderArgs
objectReader *objectReader
readCloser io.ReadCloser
}
// Read - reads single record.
func (r *Reader) Read() (sql.Record, error) {
if err := r.objectReader.Reset(); err != nil {
return nil, err
}
data, err := ioutil.ReadAll(r.objectReader)
if err != nil {
return nil, errJSONParsingError(err)
}
data = bytes.TrimSpace(data)
if len(data) == 0 {
return nil, io.EOF
}
if !gjson.ValidBytes(data) {
return nil, errJSONParsingError(errors.New("invalid JSON document"))
}
if bytes.Count(data, []byte("\n")) > 0 {
var s string
if s, err = toSingleLineJSON("", "", gjson.ParseBytes(data)); err != nil {
return nil, errJSONParsingError(err)
}
data = []byte(s)
}
return &Record{
data: data,
}, nil
}
// Close - closes underlying reader.
func (r *Reader) Close() error {
return r.readCloser.Close()
}
// NewReader - creates new JSON reader using readCloser.
func NewReader(readCloser io.ReadCloser, args *ReaderArgs) *Reader {
return &Reader{
args: args,
objectReader: &objectReader{reader: readCloser},
readCloser: readCloser,
}
}
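A short in-package sketch of the reader splitting two concatenated objects (assumes fmt, io, io/ioutil and strings are imported):

func exampleReader() {
	stream := ioutil.NopCloser(strings.NewReader(`{"a": 1}{"a": 2}`))
	r := NewReader(stream, &ReaderArgs{})
	defer r.Close()
	for {
		rec, err := r.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			panic(err)
		}
		// objectEndIndex stops each Read at the balancing close brace, so
		// every record holds exactly one top-level object.
		data, _ := rec.(*Record).MarshalJSON()
		fmt.Println(string(data))
	}
}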

pkg/s3select/json/record.go Normal file

@ -0,0 +1,107 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package json
import (
"bytes"
"encoding/csv"
"fmt"
"strings"
"github.com/minio/minio/pkg/s3select/sql"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Record - is JSON record.
type Record struct {
data []byte
}
// Get - gets the value for a column name.
func (r *Record) Get(name string) (*sql.Value, error) {
result := gjson.GetBytes(r.data, name)
switch result.Type {
case gjson.False:
return sql.NewBool(false), nil
case gjson.Number:
return sql.NewFloat(result.Float()), nil
case gjson.String:
return sql.NewString(result.String()), nil
case gjson.True:
return sql.NewBool(true), nil
}
return nil, fmt.Errorf("unsupported gjson value %v; %v", result, result.Type)
}
// Set - sets the value for a column name.
func (r *Record) Set(name string, value *sql.Value) (err error) {
var v interface{}
switch value.Type() {
case sql.Bool:
v = value.BoolValue()
case sql.Int:
v = value.IntValue()
case sql.Float:
v = value.FloatValue()
case sql.String:
v = value.StringValue()
default:
return fmt.Errorf("unsupported sql value %v and type %v", value, value.Type())
}
name = strings.Replace(name, "*", "__ALL__", -1)
r.data, err = sjson.SetBytes(r.data, name, v)
return err
}
// MarshalCSV - encodes to CSV data.
func (r *Record) MarshalCSV(fieldDelimiter rune) ([]byte, error) {
var csvRecord []string
result := gjson.ParseBytes(r.data)
result.ForEach(func(key, value gjson.Result) bool {
csvRecord = append(csvRecord, value.String())
return true
})
buf := new(bytes.Buffer)
w := csv.NewWriter(buf)
w.Comma = fieldDelimiter
if err := w.Write(csvRecord); err != nil {
return nil, err
}
w.Flush()
if err := w.Error(); err != nil {
return nil, err
}
data := buf.Bytes()
return data[:len(data)-1], nil
}
// MarshalJSON - encodes to JSON data.
func (r *Record) MarshalJSON() ([]byte, error) {
return r.data, nil
}
// NewRecord - creates new empty JSON record.
func NewRecord() *Record {
return &Record{
data: []byte("{}"),
}
}

pkg/s3select/message.go Normal file

@ -0,0 +1,384 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3select
import (
"bytes"
"encoding/binary"
"fmt"
"hash/crc32"
"net/http"
"strconv"
"sync/atomic"
"time"
)
// A message is in the format specified in
// https://docs.aws.amazon.com/AmazonS3/latest/API/images/s3select-frame-diagram-frame-overview.png
// hence the calculation is made accordingly.
func totalByteLength(headerLength, payloadLength int) int {
return 4 + 4 + 4 + headerLength + payloadLength + 4
}
func genMessage(header, payload []byte) []byte {
headerLength := len(header)
payloadLength := len(payload)
totalLength := totalByteLength(headerLength, payloadLength)
buf := new(bytes.Buffer)
binary.Write(buf, binary.BigEndian, uint32(totalLength))
binary.Write(buf, binary.BigEndian, uint32(headerLength))
prelude := buf.Bytes()
binary.Write(buf, binary.BigEndian, crc32.ChecksumIEEE(prelude))
buf.Write(header)
if payload != nil {
buf.Write(payload)
}
message := buf.Bytes()
binary.Write(buf, binary.BigEndian, crc32.ChecksumIEEE(message))
return buf.Bytes()
}
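As a concrete check of totalByteLength: the hard-coded continuation message below carries 41 bytes of headers and no payload, so its total length is 4 + 4 + 4 + 41 + 0 + 4 = 57, which is exactly the value in its first prelude field.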
// Refer genRecordsHeader().
var recordsHeader = []byte{
13, ':', 'm', 'e', 's', 's', 'a', 'g', 'e', '-', 't', 'y', 'p', 'e', 7, 0, 5, 'e', 'v', 'e', 'n', 't',
13, ':', 'c', 'o', 'n', 't', 'e', 'n', 't', '-', 't', 'y', 'p', 'e', 7, 0, 24, 'a', 'p', 'p', 'l', 'i', 'c', 'a', 't', 'i', 'o', 'n', '/', 'o', 'c', 't', 'e', 't', '-', 's', 't', 'r', 'e', 'a', 'm',
11, ':', 'e', 'v', 'e', 'n', 't', '-', 't', 'y', 'p', 'e', 7, 0, 7, 'R', 'e', 'c', 'o', 'r', 'd', 's',
}
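Decoding the first header above: 13 is the length of the name ":message-type", 7 marks a string-typed value, the pair 0, 5 is the big-endian value length, and "event" is the value. The other headers follow the same name-length/name/type/value-length/value layout.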
// newRecordsMessage - creates new Records Message which can contain a single record, partial records,
// or multiple records. Depending on the size of the result, a response can contain one or more of these messages.
//
// Header specification
// Records messages contain three headers, as follows:
// https://docs.aws.amazon.com/AmazonS3/latest/API/images/s3select-frame-diagram-record.png
//
// Payload specification
// Records message payloads can contain a single record, partial records, or multiple records.
func newRecordsMessage(payload []byte) []byte {
return genMessage(recordsHeader, payload)
}
// continuationMessage - S3 periodically sends this message to keep the TCP connection open.
// These messages appear in responses at random. The client must detect the message type and process accordingly.
//
// Header specification:
// Continuation messages contain two headers, as follows:
// https://docs.aws.amazon.com/AmazonS3/latest/API/images/s3select-frame-diagram-cont.png
//
// Payload specification:
// Continuation messages have no payload.
var continuationMessage = []byte{
0, 0, 0, 57, // total byte-length.
0, 0, 0, 41, // headers byte-length.
139, 161, 157, 242, // prelude crc.
13, ':', 'm', 'e', 's', 's', 'a', 'g', 'e', '-', 't', 'y', 'p', 'e', 7, 0, 5, 'e', 'v', 'e', 'n', 't', // headers.
11, ':', 'e', 'v', 'e', 'n', 't', '-', 't', 'y', 'p', 'e', 7, 0, 4, 'C', 'o', 'n', 't', // headers.
156, 134, 74, 13, // message crc.
}
// Refer genProgressHeader().
var progressHeader = []byte{
13, ':', 'm', 'e', 's', 's', 'a', 'g', 'e', '-', 't', 'y', 'p', 'e', 7, 0, 5, 'e', 'v', 'e', 'n', 't',
13, ':', 'c', 'o', 'n', 't', 'e', 'n', 't', '-', 't', 'y', 'p', 'e', 7, 0, 8, 't', 'e', 'x', 't', '/', 'x', 'm', 'l',
11, ':', 'e', 'v', 'e', 'n', 't', '-', 't', 'y', 'p', 'e', 7, 0, 8, 'P', 'r', 'o', 'g', 'r', 'e', 's', 's',
}
// newProgressMessage - creates new Progress Message. S3 periodically sends this message, if requested.
// It contains information about the progress of a query that has started but has not yet completed.
//
// Header specification:
// Progress messages contain three headers, as follows:
// https://docs.aws.amazon.com/AmazonS3/latest/API/images/s3select-frame-diagram-progress.png
//
// Payload specification:
// Progress message payload is an XML document containing information about the progress of a request.
// * BytesScanned => Number of bytes that have been processed before being uncompressed (if the file is compressed).
// * BytesProcessed => Number of bytes that have been processed after being uncompressed (if the file is compressed).
// * BytesReturned => Current number of bytes of records payload data returned by S3.
//
// For uncompressed files, BytesScanned and BytesProcessed are equal.
//
// Example:
//
// <?xml version="1.0" encoding="UTF-8"?>
// <Progress>
// <BytesScanned>512</BytesScanned>
// <BytesProcessed>1024</BytesProcessed>
// <BytesReturned>1024</BytesReturned>
// </Progress>
//
func newProgressMessage(bytesScanned, bytesProcessed, bytesReturned int64) []byte {
payload := []byte(`<?xml version="1.0" encoding="UTF-8"?><Progress><BytesScanned>` +
strconv.FormatInt(bytesScanned, 10) + `</BytesScanned><BytesProcessed>` +
strconv.FormatInt(bytesProcessed, 10) + `</BytesProcessed><BytesReturned>` +
strconv.FormatInt(bytesReturned, 10) + `</BytesReturned></Progress>`)
return genMessage(progressHeader, payload)
}
// Refer genStatsHeader().
var statsHeader = []byte{
13, ':', 'm', 'e', 's', 's', 'a', 'g', 'e', '-', 't', 'y', 'p', 'e', 7, 0, 5, 'e', 'v', 'e', 'n', 't',
13, ':', 'c', 'o', 'n', 't', 'e', 'n', 't', '-', 't', 'y', 'p', 'e', 7, 0, 8, 't', 'e', 'x', 't', '/', 'x', 'm', 'l',
11, ':', 'e', 'v', 'e', 'n', 't', '-', 't', 'y', 'p', 'e', 7, 0, 5, 'S', 't', 'a', 't', 's',
}
// newStatsMessage - creates new Stats Message. S3 sends this message at the end of the request.
// It contains statistics about the query.
//
// Header specification:
// Stats messages contain three headers, as follows:
// https://docs.aws.amazon.com/AmazonS3/latest/API/images/s3select-frame-diagram-stats.png
//
// Payload specification:
// Stats message payload is an XML document containing information about a request's stats when processing is complete.
// * BytesScanned => Number of bytes that have been processed before being uncompressed (if the file is compressed).
// * BytesProcessed => Number of bytes that have been processed after being uncompressed (if the file is compressed).
// * BytesReturned => Total number of bytes of records payload data returned by S3.
//
// For uncompressed files, BytesScanned and BytesProcessed are equal.
//
// Example:
//
// <?xml version="1.0" encoding="UTF-8"?>
// <Stats>
// <BytesScanned>512</BytesScanned>
// <BytesProcessed>1024</BytesProcessed>
// <BytesReturned>1024</BytesReturned>
// </Stats>
func newStatsMessage(bytesScanned, bytesProcessed, bytesReturned int64) []byte {
payload := []byte(`<?xml version="1.0" encoding="UTF-8"?><Stats><BytesScanned>` +
strconv.FormatInt(bytesScanned, 10) + `</BytesScanned><BytesProcessed>` +
strconv.FormatInt(bytesProcessed, 10) + `</BytesProcessed><BytesReturned>` +
strconv.FormatInt(bytesReturned, 10) + `</BytesReturned></Stats>`)
return genMessage(statsHeader, payload)
}
// endMessage - indicates that the request is complete, and no more messages will be sent.
// You should not assume that the request is complete until the client receives an End message.
//
// Header specification:
// End messages contain two headers, as follows:
// https://docs.aws.amazon.com/AmazonS3/latest/API/images/s3select-frame-diagram-end.png
//
// Payload specification:
// End messages have no payload.
var endMessage = []byte{
0, 0, 0, 56, // total byte-length.
0, 0, 0, 40, // headers byte-length.
193, 198, 132, 212, // prelude crc.
13, ':', 'm', 'e', 's', 's', 'a', 'g', 'e', '-', 't', 'y', 'p', 'e', 7, 0, 5, 'e', 'v', 'e', 'n', 't', // headers.
11, ':', 'e', 'v', 'e', 'n', 't', '-', 't', 'y', 'p', 'e', 7, 0, 3, 'E', 'n', 'd', // headers.
207, 151, 211, 146, // message crc.
}
// newErrorMessage - creates new Request Level Error Message. S3 sends this message if the request failed for any reason.
// It contains the error code and error message for the failure. If S3 sends a RequestLevelError message,
// it doesn't send an End message.
//
// Header specification:
// Request-level error messages contain three headers, as follows:
// https://docs.aws.amazon.com/AmazonS3/latest/API/images/s3select-frame-diagram-error.png
//
// Payload specification:
// Request-level error messages have no payload.
func newErrorMessage(errorCode, errorMessage []byte) []byte {
buf := new(bytes.Buffer)
buf.Write([]byte{13, ':', 'm', 'e', 's', 's', 'a', 'g', 'e', '-', 't', 'y', 'p', 'e', 7, 0, 5, 'e', 'r', 'r', 'o', 'r'})
buf.Write([]byte{14, ':', 'e', 'r', 'r', 'o', 'r', '-', 'm', 'e', 's', 's', 'a', 'g', 'e', 7})
binary.Write(buf, binary.BigEndian, uint16(len(errorMessage)))
buf.Write(errorMessage)
buf.Write([]byte{11, ':', 'e', 'r', 'r', 'o', 'r', '-', 'c', 'o', 'd', 'e', 7})
binary.Write(buf, binary.BigEndian, uint16(len(errorCode)))
buf.Write(errorCode)
return genMessage(buf.Bytes(), nil)
}
// NewErrorMessage - creates new Request Level Error Message specified in
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html.
func NewErrorMessage(errorCode, errorMessage string) []byte {
return newErrorMessage([]byte(errorCode), []byte(errorMessage))
}
type messageWriter struct {
writer http.ResponseWriter
getProgressFunc func() (int64, int64)
bytesReturned int64
dataCh chan []byte
doneCh chan struct{}
closeCh chan struct{}
stopped uint32
closed uint32
}
func (writer *messageWriter) write(data []byte) bool {
if _, err := writer.writer.Write(data); err != nil {
return false
}
writer.writer.(http.Flusher).Flush()
return true
}
func (writer *messageWriter) start() {
keepAliveTicker := time.NewTicker(1 * time.Second)
var progressTicker *time.Ticker
if writer.getProgressFunc != nil {
progressTicker = time.NewTicker(1 * time.Minute)
}
go func() {
quitFlag := 0
for quitFlag == 0 {
if progressTicker == nil {
select {
case data, ok := <-writer.dataCh:
if !ok {
quitFlag = 1
break
}
if !writer.write(data) {
quitFlag = 1
}
case <-writer.doneCh:
quitFlag = 2
case <-keepAliveTicker.C:
if !writer.write(continuationMessage) {
quitFlag = 1
}
}
} else {
select {
case data, ok := <-writer.dataCh:
if !ok {
quitFlag = 1
break
}
if !writer.write(data) {
quitFlag = 1
}
case <-writer.doneCh:
quitFlag = 2
case <-keepAliveTicker.C:
if !writer.write(continuationMessage) {
quitFlag = 1
}
case <-progressTicker.C:
bytesScanned, bytesProcessed := writer.getProgressFunc()
bytesReturned := atomic.LoadInt64(&writer.bytesReturned)
if !writer.write(newProgressMessage(bytesScanned, bytesProcessed, bytesReturned)) {
quitFlag = 1
}
}
}
}
atomic.StoreUint32(&writer.stopped, 1)
close(writer.closeCh)
keepAliveTicker.Stop()
if progressTicker != nil {
progressTicker.Stop()
}
if quitFlag == 2 {
for data := range writer.dataCh {
if _, err := writer.writer.Write(data); err != nil {
break
}
}
}
}()
}
func (writer *messageWriter) close() {
if atomic.SwapUint32(&writer.closed, 1) == 0 {
close(writer.doneCh)
for range writer.closeCh {
close(writer.dataCh)
}
}
}
func (writer *messageWriter) send(data []byte) error {
err := func() error {
if atomic.LoadUint32(&writer.stopped) == 1 {
return fmt.Errorf("writer already closed")
}
select {
case writer.dataCh <- data:
case <-writer.doneCh:
return fmt.Errorf("closed writer")
}
return nil
}()
if err != nil {
writer.close()
}
return err
}
func (writer *messageWriter) SendRecords(payload []byte) error {
err := writer.send(newRecordsMessage(payload))
if err == nil {
atomic.AddInt64(&writer.bytesReturned, int64(len(payload)))
}
return err
}
func (writer *messageWriter) SendStats(bytesScanned, bytesProcessed int64) error {
bytesReturned := atomic.LoadInt64(&writer.bytesReturned)
err := writer.send(newStatsMessage(bytesScanned, bytesProcessed, bytesReturned))
if err != nil {
return err
}
err = writer.send(endMessage)
writer.close()
return err
}
func (writer *messageWriter) SendError(errorCode, errorMessage string) error {
err := writer.send(newErrorMessage([]byte(errorCode), []byte(errorMessage)))
if err == nil {
writer.close()
}
return err
}
func newMessageWriter(w http.ResponseWriter, getProgressFunc func() (bytesScanned, bytesProcessed int64)) *messageWriter {
writer := &messageWriter{
writer: w,
getProgressFunc: getProgressFunc,
dataCh: make(chan []byte),
doneCh: make(chan struct{}),
closeCh: make(chan struct{}),
}
writer.start()
return writer
}
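A minimal sketch of driving messageWriter from an HTTP handler; the handler and payload are illustrative, and passing a nil progress callback disables progress frames while keep-alive frames are still sent automatically:

func exampleSelectHandler(w http.ResponseWriter, r *http.Request) {
	writer := newMessageWriter(w, nil)
	if err := writer.SendRecords([]byte(`{"name":"alice"}` + "\n")); err != nil {
		return
	}
	// SendStats also emits the End frame and shuts the writer down.
	writer.SendStats(0, 0)
}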


@ -1,460 +0,0 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// DO NOT EDIT THIS PACKAGE DIRECTLY: This follows the protocol defined by
// AmazonS3 found at
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
// Consult the Spec before making direct edits.
package s3select
import (
"bytes"
"encoding/binary"
"hash/crc32"
)
// Record Headers
// -11 -event type - 7 - 7 "Records"
// -13 -content-type -7 -24 "application/octet-stream"
// -13 -message-type -7 5 "event"
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
var recordHeaders []byte
// End Headers
// -13 -message-type -7 -5 "event"
// -11 -:event-type -7 -3 "End"
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
var endHeaders []byte
// Continuation Headers
// -13 -message-type -7 -5 "event"
// -11 -:event-type -7 -4 "Cont"
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
var contHeaders []byte
// Stat Headers
// -11 -event type - 7 - 5 "Stat" -20
// -13 -content-type -7 -8 "text/xml" -25
// -13 -message-type -7 -5 "event" -22
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
var statHeaders []byte
// Progress Headers
// -11 -event type - 7 - 8 "Progress" -23
// -13 -content-type -7 -8 "text/xml" -25
// -13 -message-type -7 -5 "event" -22
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
var progressHeaders []byte
// The length of the nonvariable portion of the ErrHeaders
// The below are the specifications of the header for a "error" event
// -11 -error-code - 7 - DEFINED "DEFINED"
// -14 -error-message -7 -DEFINED "DEFINED"
// -13 -message-type -7 -5 "error"
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
var errHdrLen int
func init() {
recordHeaders = writeRecordHeader()
endHeaders = writeEndHeader()
contHeaders = writeContHeader()
statHeaders = writeStatHeader()
progressHeaders = writeProgressHeader()
errHdrLen = 55
}
// encodeHeaderStringValue encodes a header value string into a []byte,
// prefixed with its length as a 2-byte big-endian integer.
func encodeHeaderStringValue(s string) []byte {
n := uint16(len(s))
lenSlice := make([]byte, 2)
binary.BigEndian.PutUint16(lenSlice[0:], n)
return append(lenSlice, []byte(s)...)
}
func encodeHeaderStringName(s string) []byte {
lenSlice := make([]byte, 1)
lenSlice[0] = byte(len(s))
return append(lenSlice, []byte(s)...)
}
// encodeNumber encodes a single-byte number into a []byte of length lenBytes,
// storing the value in the first byte.
func encodeNumber(n byte, lenBytes int) []byte {
lenSlice := make([]byte, lenBytes)
lenSlice[0] = n
return lenSlice
}
// writePayloadSize writes the 4byte payload size portion of the protocol.
func writePayloadSize(payloadSize int, headerLength int) []byte {
totalByteLen := make([]byte, 4)
totalMsgLen := uint32(payloadSize + headerLength + 16)
binary.BigEndian.PutUint32(totalByteLen, totalMsgLen)
return totalByteLen
}
// writeHeaderSize writes the 4byte header size portion of the protocol.
func writeHeaderSize(headerLength int) []byte {
totalHeaderLen := make([]byte, 4)
totalLen := uint32(headerLength)
binary.BigEndian.PutUint32(totalHeaderLen, totalLen)
return totalHeaderLen
}
// writeCRC writes the CRC for both the prelude and the end of the protocol.
func writeCRC(buffer []byte) []byte {
// Calculate the CRC here:
crc := make([]byte, 4)
cksum := crc32.ChecksumIEEE(buffer)
binary.BigEndian.PutUint32(crc, cksum)
return crc
}
// writePayload writes the payload for those messages where a payload is
// necessary.
func writePayload(myPayload string) []byte {
convertedPayload := []byte(myPayload)
payloadStore := make([]byte, len(convertedPayload))
copy(payloadStore[0:], myPayload)
return payloadStore
}
// writeRecordHeader is a function which writes the headers for the record
// Message
func writeRecordHeader() []byte {
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
var currentMessage = &bytes.Buffer{}
// 11 -event type - 7 - 7 "Records"
// header name
currentMessage.Write(encodeHeaderStringName(":event-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("Records"))
// Creation of the Header for Content-Type // 13 -content-type -7 -24
// "application/octet-stream"
// header name
currentMessage.Write(encodeHeaderStringName(":content-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("application/octet-stream"))
// Creation of the Header for message-type 13 -message-type -7 5 "event"
// header name
currentMessage.Write(encodeHeaderStringName(":message-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("event"))
return currentMessage.Bytes()
}
// writeEndHeader is a function which writes the headers for the end
// Message
func writeEndHeader() []byte {
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
var currentMessage = &bytes.Buffer{}
// header name
currentMessage.Write(encodeHeaderStringName(":event-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("End"))
// Creation of the Header for message-type 13 -message-type -7 5 "event"
// header name
currentMessage.Write(encodeHeaderStringName(":message-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("event"))
return currentMessage.Bytes()
}
// writeContHeader is a function which writes the headers for the continuation
// Message
func writeContHeader() []byte {
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
var currentMessage = &bytes.Buffer{}
// header name
currentMessage.Write(encodeHeaderStringName(":event-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("Cont"))
// Creation of the Header for message-type 13 -message-type -7 5 "event"
// header name
currentMessage.Write(encodeHeaderStringName(":message-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("event"))
return currentMessage.Bytes()
}
// writeStatHeader is a function which writes the headers for the Stat
// Message
func writeStatHeader() []byte {
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
var currentMessage = &bytes.Buffer{}
// header name
currentMessage.Write(encodeHeaderStringName(":event-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("Stats"))
// Creation of the Header for Content-Type // 13 -content-type -7 -8
// "text/xml"
// header name
currentMessage.Write(encodeHeaderStringName(":content-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("text/xml"))
// Creation of the Header for message-type 13 -message-type -7 5 "event"
currentMessage.Write(encodeHeaderStringName(":message-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("event"))
return currentMessage.Bytes()
}
// writeProgressHeader is a function which writes the headers for the Progress
// Message
func writeProgressHeader() []byte {
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
var currentMessage = &bytes.Buffer{}
// header name
currentMessage.Write(encodeHeaderStringName(":event-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("Progress"))
// Creation of the Header for Content-Type // 13 -content-type -7 -8
// "text/xml"
// header name
currentMessage.Write(encodeHeaderStringName(":content-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("text/xml"))
// Creation of the Header for message-type 13 -message-type -7 5 "event"
// header name
currentMessage.Write(encodeHeaderStringName(":message-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("event"))
return currentMessage.Bytes()
}
// writeRecordMessage is the function which constructs the binary message for a
// record message to be sent.
func writeRecordMessage(payload string, currentMessage *bytes.Buffer) *bytes.Buffer {
// The below are the specifications of the header for a "record" event
// 11 -event type - 7 - 7 "Records"
// 13 -content-type -7 -24 "application/octet-stream"
// 13 -message-type -7 5 "event"
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
headerLen := len(recordHeaders)
// Writes the total size of the message.
currentMessage.Write(writePayloadSize(len(payload), headerLen))
// Writes the total size of the header.
currentMessage.Write(writeHeaderSize(headerLen))
// Writes the CRC of the Prelude
currentMessage.Write(writeCRC(currentMessage.Bytes()))
currentMessage.Write(recordHeaders)
// This part is where the payload is written, this will be only one row, since
// we're sending one message at a time
currentMessage.Write(writePayload(payload))
// Now we do a CRC check on the entire messages
currentMessage.Write(writeCRC(currentMessage.Bytes()))
return currentMessage
}
// writeContinuationMessage is the function which constructs the binary message
// for a continuation message to be sent.
func writeContinuationMessage(currentMessage *bytes.Buffer) *bytes.Buffer {
// 11 -event type - 7 - 4 "Cont"
// 13 -message-type -7 5 "event"
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
headerLen := len(contHeaders)
currentMessage.Write(writePayloadSize(0, headerLen))
currentMessage.Write(writeHeaderSize(headerLen))
// Calculate the Prelude CRC here:
currentMessage.Write(writeCRC(currentMessage.Bytes()))
currentMessage.Write(contHeaders)
//Now we do a CRC check on the entire messages
currentMessage.Write(writeCRC(currentMessage.Bytes()))
return currentMessage
}
// writeEndMessage is the function which constructs the binary message
// for an end message to be sent.
func writeEndMessage(currentMessage *bytes.Buffer) *bytes.Buffer {
// 11 -event type - 7 - 3 "End"
// 13 -message-type -7 5 "event"
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
headerLen := len(endHeaders)
currentMessage.Write(writePayloadSize(0, headerLen))
currentMessage.Write(writeHeaderSize(headerLen))
//Calculate the Prelude CRC here:
currentMessage.Write(writeCRC(currentMessage.Bytes()))
currentMessage.Write(endHeaders)
// Now we do a CRC check on the entire messages
currentMessage.Write(writeCRC(currentMessage.Bytes()))
return currentMessage
}
// writeStatMessage is the function which constructs the binary message for a
// stat message to be sent.
func writeStatMessage(payload string, currentMessage *bytes.Buffer) *bytes.Buffer {
// 11 -event type - 7 - 5 "Stat" 20
// 13 -content-type -7 -8 "text/xml" 25
// 13 -message-type -7 5 "event" 22
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
headerLen := len(statHeaders)
currentMessage.Write(writePayloadSize(len(payload), headerLen))
currentMessage.Write(writeHeaderSize(headerLen))
currentMessage.Write(writeCRC(currentMessage.Bytes()))
currentMessage.Write(statHeaders)
// This part is where the payload is written, this will be only one row, since
// we're sending one message at a time
currentMessage.Write(writePayload(payload))
// Now we do a CRC check on the entire messages
currentMessage.Write(writeCRC(currentMessage.Bytes()))
return currentMessage
}
// writeProgressMessage is the function which constructs the binary message for
// a progress message to be sent.
func writeProgressMessage(payload string, currentMessage *bytes.Buffer) *bytes.Buffer {
// The below are the specifications of the header for a "Progress" event
// 11 -event type - 7 - 8 "Progress" 23
// 13 -content-type -7 -8 "text/xml" 25
// 13 -message-type -7 5 "event" 22
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
headerLen := len(progressHeaders)
currentMessage.Write(writePayloadSize(len(payload), headerLen))
currentMessage.Write(writeHeaderSize(headerLen))
currentMessage.Write(writeCRC(currentMessage.Bytes()))
currentMessage.Write(progressHeaders)
// This part is where the payload is written, this will be only one row, since
// we're sending one message at a time
currentMessage.Write(writePayload(payload))
// Now we do a CRC check on the entire messages
currentMessage.Write(writeCRC(currentMessage.Bytes()))
return currentMessage
}
// writeErrorMessage is the function which constructs the binary message for an
// error message to be sent.
func writeErrorMessage(errorMessage error, currentMessage *bytes.Buffer) *bytes.Buffer {
// The below are the specifications of the header for a "error" event
// 11 -error-code - 7 - DEFINED "DEFINED"
// 14 -error-message -7 -DEFINED "DEFINED"
// 13 -message-type -7 5 "error"
// This is predefined from AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
sizeOfErrorCode := len(errorCodeResponse[errorMessage])
sizeOfErrorMessage := len(errorMessage.Error())
headerLen := errHdrLen + sizeOfErrorCode + sizeOfErrorMessage
currentMessage.Write(writePayloadSize(0, headerLen))
currentMessage.Write(writeHeaderSize(headerLen))
currentMessage.Write(writeCRC(currentMessage.Bytes()))
// header name
currentMessage.Write(encodeHeaderStringName(":error-code"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue(errorCodeResponse[errorMessage]))
// 14 -error-message -7 -DEFINED "DEFINED"
// header name
currentMessage.Write(encodeHeaderStringName(":error-message"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue(errorMessage.Error()))
// Creation of the Header for message-type 13 -message-type -7 5 "error"
// header name
currentMessage.Write(encodeHeaderStringName(":message-type"))
// header type
currentMessage.Write(encodeNumber(7, 1))
// header value and header value length
currentMessage.Write(encodeHeaderStringValue("error"))
// Now we do a CRC check on the entire messages
currentMessage.Write(writeCRC(currentMessage.Bytes()))
return currentMessage
}


@ -0,0 +1,42 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parquet
import "encoding/xml"
// ReaderArgs - represents elements inside <InputSerialization><Parquet/> in request XML.
type ReaderArgs struct {
unmarshaled bool
}
// IsEmpty - returns whether reader args is empty or not.
func (args *ReaderArgs) IsEmpty() bool {
return !args.unmarshaled
}
// UnmarshalXML - decodes XML data.
func (args *ReaderArgs) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
// Make subtype to avoid recursive UnmarshalXML().
type subReaderArgs ReaderArgs
parsedArgs := subReaderArgs{}
if err := d.DecodeElement(&parsedArgs, &start); err != nil {
return err
}
args.unmarshaled = true
return nil
}


@ -0,0 +1,53 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parquet
type s3Error struct {
code string
message string
statusCode int
cause error
}
func (err *s3Error) Cause() error {
return err.cause
}
func (err *s3Error) ErrorCode() string {
return err.code
}
func (err *s3Error) ErrorMessage() string {
return err.message
}
func (err *s3Error) HTTPStatusCode() int {
return err.statusCode
}
func (err *s3Error) Error() string {
return err.message
}
func errParquetParsingError(err error) *s3Error {
return &s3Error{
code: "ParquetParsingError",
message: "Error parsing Parquet file. Please check the file and try again.",
statusCode: 400,
cause: err,
}
}


@ -0,0 +1,93 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parquet
import (
"io"
"github.com/minio/minio/pkg/s3select/json"
"github.com/minio/minio/pkg/s3select/sql"
parquetgo "github.com/minio/parquet-go"
parquetgen "github.com/minio/parquet-go/gen-go/parquet"
)
// Reader - Parquet record reader for S3Select.
type Reader struct {
args *ReaderArgs
file *parquetgo.File
}
// Read - reads single record.
func (r *Reader) Read() (sql.Record, error) {
parquetRecord, err := r.file.Read()
if err != nil {
if err != io.EOF {
return nil, errParquetParsingError(err)
}
return nil, err
}
record := json.NewRecord()
for name, v := range parquetRecord {
var value *sql.Value
switch v.Type {
case parquetgen.Type_BOOLEAN:
value = sql.NewBool(v.Value.(bool))
case parquetgen.Type_INT32:
value = sql.NewInt(int64(v.Value.(int32)))
case parquetgen.Type_INT64:
value = sql.NewInt(v.Value.(int64))
case parquetgen.Type_FLOAT:
value = sql.NewFloat(float64(v.Value.(float32)))
case parquetgen.Type_DOUBLE:
value = sql.NewFloat(v.Value.(float64))
case parquetgen.Type_INT96, parquetgen.Type_BYTE_ARRAY, parquetgen.Type_FIXED_LEN_BYTE_ARRAY:
value = sql.NewString(string(v.Value.([]byte)))
default:
return nil, errParquetParsingError(nil)
}
if err = record.Set(name, value); err != nil {
return nil, errParquetParsingError(err)
}
}
return record, nil
}
// Close - closes underlying readers.
func (r *Reader) Close() error {
return r.file.Close()
}
// NewReader - creates new Parquet reader using readerFunc callback.
func NewReader(getReaderFunc func(offset, length int64) (io.ReadCloser, error), args *ReaderArgs) (*Reader, error) {
file, err := parquetgo.Open(getReaderFunc, nil)
if err != nil {
if err != io.EOF {
return nil, errParquetParsingError(err)
}
return nil, err
}
return &Reader{
args: args,
file: file,
}, nil
}
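A sketch of feeding the reader from a local file (assumes the os package is imported). parquet-go pulls (offset, length) ranges on demand; treating a negative offset as relative to the end of the file is an assumption about the callback contract, not something defined in this file:

type sectionReadCloser struct {
	io.Reader
	f *os.File
}

func (s sectionReadCloser) Close() error { return s.f.Close() }

func openLocalParquet(path string) (*Reader, error) {
	getReader := func(offset, length int64) (io.ReadCloser, error) {
		f, err := os.Open(path)
		if err != nil {
			return nil, err
		}
		if offset < 0 { // interpreted as relative to end of file (assumption)
			fi, err := f.Stat()
			if err != nil {
				f.Close()
				return nil, err
			}
			offset += fi.Size()
		}
		if _, err = f.Seek(offset, io.SeekStart); err != nil {
			f.Close()
			return nil, err
		}
		return sectionReadCloser{io.LimitReader(f, length), f}, nil
	}
	return NewReader(getReader, &ReaderArgs{})
}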

pkg/s3select/progress.go Normal file

@ -0,0 +1,90 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3select
import (
"compress/bzip2"
"fmt"
"io"
"sync/atomic"
gzip "github.com/klauspost/pgzip"
)
type countUpReader struct {
reader io.Reader
bytesRead int64
}
func (r *countUpReader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
atomic.AddInt64(&r.bytesRead, int64(n))
return n, err
}
func (r *countUpReader) BytesRead() int64 {
return atomic.LoadInt64(&r.bytesRead)
}
func newCountUpReader(reader io.Reader) *countUpReader {
return &countUpReader{
reader: reader,
}
}
type progressReader struct {
rc io.ReadCloser
scannedReader *countUpReader
processedReader *countUpReader
}
func (pr *progressReader) Read(p []byte) (n int, err error) {
return pr.processedReader.Read(p)
}
func (pr *progressReader) Close() error {
return pr.rc.Close()
}
func (pr *progressReader) Stats() (bytesScanned, bytesProcessed int64) {
return pr.scannedReader.BytesRead(), pr.processedReader.BytesRead()
}
func newProgressReader(rc io.ReadCloser, compType CompressionType) (*progressReader, error) {
scannedReader := newCountUpReader(rc)
var r io.Reader
var err error
switch compType {
case noneType:
r = scannedReader
case gzipType:
if r, err = gzip.NewReader(scannedReader); err != nil {
return nil, errTruncatedInput(err)
}
case bzip2Type:
r = bzip2.NewReader(scannedReader)
default:
return nil, errInvalidCompressionFormat(fmt.Errorf("unknown compression type '%v'", compType))
}
return &progressReader{
rc: rc,
scannedReader: scannedReader,
processedReader: newCountUpReader(r),
}, nil
}
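A small sketch of the two counters at work (assumes fmt, io and io/ioutil are imported): for a gzip body, scannedReader counts compressed bytes read off the wire while processedReader counts the decompressed bytes handed to the record reader, matching the BytesScanned/BytesProcessed split in progress and stats messages.

func exampleProgress(body io.ReadCloser) {
	pr, err := newProgressReader(body, gzipType)
	if err != nil {
		return
	}
	defer pr.Close()
	io.Copy(ioutil.Discard, pr)
	scanned, processed := pr.Stats()
	// For uncompressed input the two values are equal; with gzip,
	// processed is typically larger than scanned.
	fmt.Printf("BytesScanned=%d BytesProcessed=%d\n", scanned, processed)
}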


@ -1,5 +1,5 @@
/*
 * Minio Cloud Storage, (C) 2018 Minio, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
/*
 * Minio Cloud Storage, (C) 2019 Minio, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@ -17,436 +17,377 @@
package s3select
import (
"math"
"sort"
"strconv"
"strings"
"github.com/minio/minio/pkg/s3select/format"
"github.com/tidwall/gjson"
"github.com/xwb1989/sqlparser"
)
// SelectFuncs contains the relevant values from the parser for S3 Select
// Functions
type SelectFuncs struct {
funcExpr []*sqlparser.FuncExpr
index []int
}
// runSelectParser allows us to easily bundle all the functions from above and run
// them in the appropriate order.
func runSelectParser(f format.Select, rowCh chan Row) {
reqCols, alias, limit, wc, aggFunctionNames, fns, err := ParseSelect(f)
if err != nil {
rowCh <- Row{
err: err,
}
return
}
processSelectReq(reqCols, alias, wc, limit, aggFunctionNames, rowCh, fns, f)
}
package s3select
import (
"encoding/xml"
"fmt"
"io"
"net/http"
"strings"
"github.com/minio/minio/pkg/s3select/csv"
"github.com/minio/minio/pkg/s3select/json"
"github.com/minio/minio/pkg/s3select/parquet"
"github.com/minio/minio/pkg/s3select/sql"
)
type recordReader interface {
Read() (sql.Record, error)
Close() error
}
const (
csvFormat = "csv"
jsonFormat = "json"
parquetFormat = "parquet"
)
// CompressionType - represents value inside <CompressionType/> in request XML.
type CompressionType string
const (
noneType CompressionType = "none"
gzipType CompressionType = "gzip"
bzip2Type CompressionType = "bzip2"
)
// UnmarshalXML - decodes XML data.
func (c *CompressionType) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
var s string
if err := d.DecodeElement(&s, &start); err != nil {
return errMalformedXML(err)
}
parsedType := CompressionType(strings.ToLower(s))
if s == "" {
parsedType = noneType
}
switch parsedType {
case noneType, gzipType, bzip2Type:
default:
return errInvalidCompressionFormat(fmt.Errorf("invalid compression format '%v'", s))
}
*c = parsedType
return nil
}
// InputSerialization - represents elements inside <InputSerialization/> in request XML.
type InputSerialization struct {
CompressionType CompressionType `xml:"CompressionType"`
CSVArgs csv.ReaderArgs `xml:"CSV"`
JSONArgs json.ReaderArgs `xml:"JSON"`
ParquetArgs parquet.ReaderArgs `xml:"Parquet"`
unmarshaled bool
format string
}
// IsEmpty - returns whether input serialization is empty or not.
func (input *InputSerialization) IsEmpty() bool {
return !input.unmarshaled
}
// UnmarshalXML - decodes XML data.
func (input *InputSerialization) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
// Make subtype to avoid recursive UnmarshalXML().
type subInputSerialization InputSerialization
parsedInput := subInputSerialization{}
if err := d.DecodeElement(&parsedInput, &start); err != nil {
return errMalformedXML(err)
}
found := 0
if !parsedInput.CSVArgs.IsEmpty() {
parsedInput.format = csvFormat
found++
}
if !parsedInput.JSONArgs.IsEmpty() {
parsedInput.format = jsonFormat
found++
}
if !parsedInput.ParquetArgs.IsEmpty() {
if parsedInput.CompressionType != noneType {
return errInvalidRequestParameter(fmt.Errorf("CompressionType must be NONE for Parquet format"))
}
parsedInput.format = parquetFormat
found++
}
if found != 1 {
return errInvalidDataSource(nil)
}
*input = InputSerialization(parsedInput)
input.unmarshaled = true
return nil
}
// OutputSerialization - represents elements inside <OutputSerialization/> in request XML.
type OutputSerialization struct {
CSVArgs csv.WriterArgs `xml:"CSV"`
JSONArgs json.WriterArgs `xml:"JSON"`
unmarshaled bool
format string
}
// IsEmpty - returns whether output serialization is empty or not.
func (output *OutputSerialization) IsEmpty() bool {
return !output.unmarshaled
}
// UnmarshalXML - decodes XML data.
func (output *OutputSerialization) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
// Make subtype to avoid recursive UnmarshalXML().
type subOutputSerialization OutputSerialization
parsedOutput := subOutputSerialization{}
if err := d.DecodeElement(&parsedOutput, &start); err != nil {
return errMalformedXML(err)
}
found := 0
if !parsedOutput.CSVArgs.IsEmpty() {
parsedOutput.format = csvFormat
found++
}
if !parsedOutput.JSONArgs.IsEmpty() {
parsedOutput.format = jsonFormat
found++
}
if found != 1 {
return errObjectSerializationConflict(fmt.Errorf("either CSV or JSON should be present in OutputSerialization"))
}
*output = OutputSerialization(parsedOutput)
output.unmarshaled = true
return nil
}
// RequestProgress - represents elements inside <RequestProgress/> in request XML.
type RequestProgress struct {
Enabled bool `xml:"Enabled"`
}
// S3Select - filters the contents on a simple structured query language (SQL) statement. It
// represents elements inside <SelectObjectContentRequest/> in request XML specified in detail at
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html.
type S3Select struct {
XMLName xml.Name `xml:"SelectObjectContentRequest"`
Expression string `xml:"Expression"`
ExpressionType string `xml:"ExpressionType"`
Input InputSerialization `xml:"InputSerialization"`
Output OutputSerialization `xml:"OutputSerialization"`
Progress RequestProgress `xml:"RequestProgress"`
statement *sql.Select
progressReader *progressReader
recordReader recordReader
}
// UnmarshalXML - decodes XML data.
func (s3Select *S3Select) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
// Make subtype to avoid recursive UnmarshalXML().
type subS3Select S3Select
parsedS3Select := subS3Select{}
if err := d.DecodeElement(&parsedS3Select, &start); err != nil {
if _, ok := err.(*s3Error); ok {
return err
}
return errMalformedXML(err)
}
parsedS3Select.ExpressionType = strings.ToLower(parsedS3Select.ExpressionType)
if parsedS3Select.ExpressionType != "sql" {
return errInvalidExpressionType(fmt.Errorf("invalid expression type '%v'", parsedS3Select.ExpressionType))
}
if parsedS3Select.Input.IsEmpty() {
return errMissingRequiredParameter(fmt.Errorf("InputSerialization must be provided"))
}
if parsedS3Select.Output.IsEmpty() {
return errMissingRequiredParameter(fmt.Errorf("OutputSerialization must be provided"))
}
statement, err := sql.NewSelect(parsedS3Select.Expression)
if err != nil { if err != nil {
rowCh <- Row{ return err
err: err,
}
return
}
processSelectReq(reqCols, alias, wc, limit, aggFunctionNames, rowCh, fns, f)
}
// ParseSelect parses the SELECT expression, and effectively tokenizes it into
// its separate parts. It returns the requested column names,alias,limit of
// records, and the where clause.
func ParseSelect(f format.Select) ([]string, string, int64, sqlparser.Expr, []string, SelectFuncs, error) {
var sFuncs = SelectFuncs{}
var whereClause sqlparser.Expr
var alias string
var limit int64
stmt, err := sqlparser.Parse(f.Expression())
// TODO: Maybe can parse their errors a bit to return some more of the s3 errors
if err != nil {
return nil, "", 0, nil, nil, sFuncs, ErrLexerInvalidChar
} }
switch stmt := stmt.(type) { parsedS3Select.statement = statement
case *sqlparser.Select:
// evaluates the where clause
fnNames := make([]string, len(stmt.SelectExprs))
columnNames := make([]string, len(stmt.SelectExprs))
if stmt.Where != nil { *s3Select = S3Select(parsedS3Select)
whereClause = stmt.Where.Expr return nil
} }
for i, sexpr := range stmt.SelectExprs {
switch expr := sexpr.(type) {
case *sqlparser.StarExpr:
columnNames[0] = "*"
case *sqlparser.AliasedExpr:
switch smallerexpr := expr.Expr.(type) {
case *sqlparser.FuncExpr:
if smallerexpr.IsAggregate() {
fnNames[i] = smallerexpr.Name.CompliantName()
// Will return function name
// Case to deal with if we have functions and not an asterix
switch tempagg := smallerexpr.Exprs[0].(type) {
case *sqlparser.StarExpr:
columnNames[0] = "*"
if smallerexpr.Name.CompliantName() != "count" {
return nil, "", 0, nil, nil, sFuncs, ErrParseUnsupportedCallWithStar
}
case *sqlparser.AliasedExpr:
switch col := tempagg.Expr.(type) {
case *sqlparser.BinaryExpr:
return nil, "", 0, nil, nil, sFuncs, ErrParseNonUnaryAgregateFunctionCall
case *sqlparser.ColName:
columnNames[i] = col.Name.CompliantName()
}
}
// Case to deal with if COALESCE was used..
} else if supportedFunc(smallerexpr.Name.CompliantName()) {
if sFuncs.funcExpr == nil {
sFuncs.funcExpr = make([]*sqlparser.FuncExpr, len(stmt.SelectExprs))
sFuncs.index = make([]int, len(stmt.SelectExprs))
}
sFuncs.funcExpr[i] = smallerexpr
sFuncs.index[i] = i
} else {
return nil, "", 0, nil, nil, sFuncs, ErrUnsupportedSQLOperation
}
case *sqlparser.ColName:
columnNames[i] = smallerexpr.Name.CompliantName()
}
}
}
// This code retrieves the alias and makes sure it is set to the correct func (s3Select *S3Select) outputRecord() sql.Record {
// value, if not it sets it to the tablename switch s3Select.Output.format {
for _, fexpr := range stmt.From { case csvFormat:
switch smallerexpr := fexpr.(type) { return csv.NewRecord()
case *sqlparser.JoinTableExpr: case jsonFormat:
return nil, "", 0, nil, nil, sFuncs, ErrParseMalformedJoin return json.NewRecord()
case *sqlparser.AliasedTableExpr:
alias = smallerexpr.As.CompliantName()
if alias == "" {
alias = sqlparser.GetTableName(smallerexpr.Expr).CompliantName()
}
}
}
if stmt.Limit != nil {
switch expr := stmt.Limit.Rowcount.(type) {
case *sqlparser.SQLVal:
// The Value of how many rows we're going to limit by
parsedLimit, _ := strconv.Atoi(string(expr.Val[:]))
limit = int64(parsedLimit)
}
}
if stmt.GroupBy != nil {
return nil, "", 0, nil, nil, sFuncs, ErrParseUnsupportedLiteralsGroupBy
}
if stmt.OrderBy != nil {
return nil, "", 0, nil, nil, sFuncs, ErrParseUnsupportedToken
}
if err := parseErrs(columnNames, whereClause, alias, sFuncs, f); err != nil {
return nil, "", 0, nil, nil, sFuncs, err
}
return columnNames, alias, limit, whereClause, fnNames, sFuncs, nil
} }
return nil, "", 0, nil, nil, sFuncs, nil
panic(fmt.Errorf("unknown output format '%v'", s3Select.Output.format))
} }
type columnKv struct { func (s3Select *S3Select) getProgress() (bytesScanned, bytesProcessed int64) {
Key string if s3Select.progressReader != nil {
Value int return s3Select.progressReader.Stats()
}
return -1, -1
} }
func columnsIndex(reqColNames []string, f format.Select) ([]columnKv, error) { // Open - opens S3 object by using callback for SQL selection query.
var ( // Currently CSV, JSON and Apache Parquet formats are supported.
columnsKv []columnKv func (s3Select *S3Select) Open(getReader func(offset, length int64) (io.ReadCloser, error)) error {
columnsMap = make(map[string]int) switch s3Select.Input.format {
columns = f.Header() case csvFormat:
) rc, err := getReader(0, -1)
if f.HasHeader() { if err != nil {
err := checkForDuplicates(columns, columnsMap) return err
if format.IsInt(reqColNames[0]) {
err = ErrMissingHeaders
} }
s3Select.progressReader, err = newProgressReader(rc, s3Select.Input.CompressionType)
if err != nil {
return err
}
s3Select.recordReader, err = csv.NewReader(s3Select.progressReader, &s3Select.Input.CSVArgs)
if err != nil {
return err
}
return nil
case jsonFormat:
rc, err := getReader(0, -1)
if err != nil {
return err
}
s3Select.progressReader, err = newProgressReader(rc, s3Select.Input.CompressionType)
if err != nil {
return err
}
s3Select.recordReader = json.NewReader(s3Select.progressReader, &s3Select.Input.JSONArgs)
return nil
case parquetFormat:
var err error
s3Select.recordReader, err = parquet.NewReader(getReader, &s3Select.Input.ParquetArgs)
return err
}
panic(fmt.Errorf("unknown input format '%v'", s3Select.Input.format))
}
func (s3Select *S3Select) marshal(record sql.Record) ([]byte, error) {
switch s3Select.Output.format {
case csvFormat:
data, err := record.MarshalCSV([]rune(s3Select.Output.CSVArgs.FieldDelimiter)[0])
if err != nil { if err != nil {
return nil, err return nil, err
} }
for k, v := range columnsMap {
columnsKv = append(columnsKv, columnKv{ return append(data, []byte(s3Select.Output.CSVArgs.RecordDelimiter)...), nil
Key: k, case jsonFormat:
Value: v, data, err := record.MarshalJSON()
}) if err != nil {
} return nil, err
} else {
for i := range columns {
columnsKv = append(columnsKv, columnKv{
Key: "_" + strconv.Itoa(i),
Value: i,
})
} }
return append(data, []byte(s3Select.Output.JSONArgs.RecordDelimiter)...), nil
} }
sort.Slice(columnsKv, func(i, j int) bool {
return columnsKv[i].Value < columnsKv[j].Value panic(fmt.Errorf("unknown output format '%v'", s3Select.Output.format))
})
return columnsKv, nil
} }
// This is the main function, It goes row by row and for records which validate // Evaluate - filters and sends records read from opened reader as per select statement to http response writer.
// the where clause it currently prints the appropriate row given the requested func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
// columns. getProgressFunc := s3Select.getProgress
func processSelectReq(reqColNames []string, alias string, wc sqlparser.Expr, lrecords int64, fnNames []string, rowCh chan Row, fn SelectFuncs, f format.Select) { if !s3Select.Progress.Enabled {
counter := -1 getProgressFunc = nil
filtrCount := 0
functionFlag := false
// Values used to store our aggregation values.
aggVals := make([]float64, len(reqColNames))
if lrecords == 0 {
lrecords = math.MaxInt64
} }
writer := newMessageWriter(w, getProgressFunc)
var results []string var inputRecord sql.Record
var columnsKv []columnKv var outputRecord sql.Record
if f.Type() == format.CSV { var err error
var err error var data []byte
columnsKv, err = columnsIndex(reqColNames, f) sendRecord := func() bool {
if err != nil { if outputRecord == nil {
rowCh <- Row{ return true
err: err,
}
return
} }
results = make([]string, len(columnsKv))
if data, err = s3Select.marshal(outputRecord); err != nil {
return false
}
if err = writer.SendRecords(data); err != nil {
// FIXME: log this error.
err = nil
return false
}
return true
} }
for { for {
record, err := f.Read() if inputRecord, err = s3Select.recordReader.Read(); err != nil {
if err != nil { if err != io.EOF {
rowCh <- Row{ break
err: err,
} }
return
} if s3Select.statement.IsAggregated() {
if record == nil { outputRecord = s3Select.outputRecord()
if functionFlag { if err = s3Select.statement.AggregateResult(outputRecord); err != nil {
rowCh <- Row{ break
record: aggFuncToStr(aggVals, f) + "\n", }
if !sendRecord() {
break
} }
} }
close(rowCh)
return
}
// For JSON multi-line input type columns needs if err = writer.SendStats(s3Select.getProgress()); err != nil {
// to be handled for each record. // FIXME: log this error.
if f.Type() == format.JSON { err = nil
columnsKv, err = columnsIndex(reqColNames, f)
if err != nil {
rowCh <- Row{
err: err,
}
return
} }
results = make([]string, len(columnsKv))
break
} }
f.UpdateBytesProcessed(int64(len(record))) outputRecord = s3Select.outputRecord()
if outputRecord, err = s3Select.statement.Eval(inputRecord, outputRecord); err != nil {
// Return in case the number of record reaches the LIMIT break
// defined in select query
if int64(filtrCount) == lrecords {
close(rowCh)
return
} }
// The call to the where function clause, ensures that if !s3Select.statement.IsAggregated() {
// the rows we print match our where clause. if !sendRecord() {
condition, err := matchesMyWhereClause(record, alias, wc) break
if err != nil {
rowCh <- Row{
err: err,
} }
return
} }
}
if condition { if err != nil {
// if its an asterix we just print everything in the row if serr := writer.SendError("InternalError", err.Error()); serr != nil {
if reqColNames[0] == "*" && fnNames[0] == "" { // FIXME: log errors.
switch f.OutputType() {
case format.CSV:
for i, kv := range columnsKv {
results[i] = gjson.GetBytes(record, kv.Key).String()
}
rowCh <- Row{
record: strings.Join(results, f.OutputFieldDelimiter()) + f.OutputRecordDelimiter(),
}
case format.JSON:
rowCh <- Row{
record: string(record) + f.OutputRecordDelimiter(),
}
}
} else if alias != "" {
// This is for dealing with the case of if we have to deal with a
// request for a column with an index e.g A_1.
if format.IsInt(reqColNames[0]) {
// This checks whether any aggregation function was called as now we
// no longer will go through printing each row, and only print at the end
if len(fnNames) > 0 && fnNames[0] != "" {
functionFlag = true
aggregationFns(counter, filtrCount, aggVals, reqColNames, fnNames, record)
} else {
// The code below finds the appropriate columns of the row given the
// indicies provided in the SQL request.
var rowStr string
rowStr, err = processColNameIndex(record, reqColNames, f)
if err != nil {
rowCh <- Row{
err: err,
}
return
}
rowCh <- Row{
record: rowStr + "\n",
}
}
} else {
// This code does aggregation if we were provided column names in the
// form of actual names rather an indices.
if len(fnNames) > 0 && fnNames[0] != "" {
functionFlag = true
aggregationFns(counter, filtrCount, aggVals, reqColNames, fnNames, record)
} else {
// This code prints the appropriate part of the row given the filter
// and select request, if the select request was based on column
// names rather than indices.
var rowStr string
rowStr, err = processColNameLiteral(record, reqColNames, fn, f)
if err != nil {
rowCh <- Row{
err: err,
}
return
}
rowCh <- Row{
record: rowStr + "\n",
}
}
}
}
filtrCount++
} }
counter++
} }
} }
// processColumnNames is a function which allows for cleaning of column names. // Close - closes opened S3 object.
func processColumnNames(reqColNames []string, alias string, f format.Select) error { func (s3Select *S3Select) Close() error {
switch f.Type() { return s3Select.recordReader.Close()
case format.CSV:
for i := range reqColNames {
// The code below basically cleans the column name of its alias and other
// syntax, so that we can extract its pure name.
reqColNames[i] = cleanCol(reqColNames[i], alias)
}
case format.JSON:
// JSON doesnt have columns so no cleaning required
}
return nil
} }
// processColNameIndex is the function which creates the row for an index based query. // NewS3Select - creates new S3Select by given request XML reader.
func processColNameIndex(record []byte, reqColNames []string, f format.Select) (string, error) { func NewS3Select(r io.Reader) (*S3Select, error) {
var row []string s3Select := &S3Select{}
for _, colName := range reqColNames { if err := xml.NewDecoder(r).Decode(s3Select); err != nil {
// COALESCE AND NULLIF do not support index based access. return nil, err
if reqColNames[0] == "0" { }
return "", format.ErrInvalidColumnIndex
}
cindex, err := strconv.Atoi(colName)
if err != nil {
return "", ErrMissingHeaders
}
if cindex > len(f.Header()) {
return "", format.ErrInvalidColumnIndex
}
// Subtract 1 because SELECT indexing is not 0 based, it return s3Select, nil
// starts at 1 generating the key like "_1".
row = append(row, gjson.GetBytes(record, string("_"+strconv.Itoa(cindex-1))).String())
}
rowStr := strings.Join(row, f.OutputFieldDelimiter())
if len(rowStr) > MaxCharsPerRecord {
return "", ErrOverMaxRecordSize
}
return rowStr, nil
}
// processColNameLiteral is the function which creates the row for an name based query.
func processColNameLiteral(record []byte, reqColNames []string, fn SelectFuncs, f format.Select) (string, error) {
row := make([]string, len(reqColNames))
for i, colName := range reqColNames {
// this is the case to deal with COALESCE.
if colName == "" && isValidFunc(fn.index, i) {
row[i] = evaluateFuncExpr(fn.funcExpr[i], "", record)
continue
}
row[i] = gjson.GetBytes(record, colName).String()
}
rowStr := strings.Join(row, f.OutputFieldDelimiter())
if len(rowStr) > MaxCharsPerRecord {
return "", ErrOverMaxRecordSize
}
return rowStr, nil
}
// aggregationFns is a function which performs the actual aggregation
// methods on the given row, it uses an array defined in the main parsing
// function to keep track of values.
func aggregationFns(counter int, filtrCount int, aggVals []float64, storeReqCols []string, storeFns []string, record []byte) error {
for i, storeFn := range storeFns {
switch storeFn {
case "":
continue
case "count":
aggVals[i]++
default:
// Column names are provided as an index it'll use
// this if statement instead.
var convAggFloat float64
if format.IsInt(storeReqCols[i]) {
index, _ := strconv.Atoi(storeReqCols[i])
convAggFloat = gjson.GetBytes(record, "_"+strconv.Itoa(index)).Float()
} else {
// Named columns rather than indices.
convAggFloat = gjson.GetBytes(record, storeReqCols[i]).Float()
}
switch storeFn {
case "min":
if counter == -1 {
aggVals[i] = math.MaxFloat64
}
if convAggFloat < aggVals[i] {
aggVals[i] = convAggFloat
}
case "max":
// Calculate the max.
if counter == -1 {
aggVals[i] = math.SmallestNonzeroFloat64
}
if convAggFloat > aggVals[i] {
aggVals[i] = convAggFloat
}
case "sum":
// Calculate the sum.
aggVals[i] += convAggFloat
case "avg":
// Calculating the average.
if filtrCount == 0 {
aggVals[i] = convAggFloat
} else {
aggVals[i] = (convAggFloat + (aggVals[i] * float64(filtrCount))) / float64((filtrCount + 1))
}
default:
return ErrParseNonUnaryAgregateFunctionCall
}
}
}
return nil
} }
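Taken together, the new API is driven as NewS3Select → Open → Evaluate → Close. A minimal sketch of a caller, assuming the package is imported under its repository path; the in-memory object and the httptest recorder are illustrative stand-ins for a real object reader and HTTP response writer:

package main

import (
	"bytes"
	"io"
	"io/ioutil"
	"net/http/httptest"

	"github.com/minio/minio/pkg/s3select"
)

// runQuery parses a SelectObjectContentRequest, runs it over objectData
// and returns the raw event-stream response produced by Evaluate.
func runQuery(requestXML, objectData []byte) ([]byte, error) {
	s3Select, err := s3select.NewS3Select(bytes.NewReader(requestXML))
	if err != nil {
		return nil, err
	}

	// CSV and JSON inputs read the whole object (offset 0, length -1);
	// a Parquet input would call back with real offsets, so ignoring
	// them is only safe for this in-memory, non-Parquet example.
	if err = s3Select.Open(func(offset, length int64) (io.ReadCloser, error) {
		return ioutil.NopCloser(bytes.NewReader(objectData)), nil
	}); err != nil {
		return nil, err
	}
	defer s3Select.Close()

	w := httptest.NewRecorder() // any http.ResponseWriter works here
	s3Select.Evaluate(w)
	return w.Body.Bytes(), nil
}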

View File

@ -0,0 +1,170 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3select
import (
"bytes"
"encoding/csv"
"io"
"io/ioutil"
"math/rand"
"net/http"
"strconv"
"testing"
"time"
humanize "github.com/dustin/go-humanize"
)
var randSrc = rand.New(rand.NewSource(time.Now().UnixNano()))
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
func newRandString(length int) string {
b := make([]byte, length)
for i := range b {
b[i] = charset[randSrc.Intn(len(charset))]
}
return string(b)
}
func genSampleCSVData(count int) []byte {
buf := &bytes.Buffer{}
csvWriter := csv.NewWriter(buf)
csvWriter.Write([]string{"id", "name", "age", "city"})
for i := 0; i < count; i++ {
csvWriter.Write([]string{
strconv.Itoa(i),
newRandString(10),
newRandString(5),
newRandString(10),
})
}
csvWriter.Flush()
return buf.Bytes()
}
type nullResponseWriter struct {
}
func (w *nullResponseWriter) Header() http.Header {
return nil
}
func (w *nullResponseWriter) Write(p []byte) (int, error) {
return len(p), nil
}
func (w *nullResponseWriter) WriteHeader(statusCode int) {
}
func (w *nullResponseWriter) Flush() {
}
func benchmarkSelect(b *testing.B, count int, query string) {
var requestXML = []byte(`
<?xml version="1.0" encoding="UTF-8"?>
<SelectObjectContentRequest>
<Expression>` + query + `</Expression>
<ExpressionType>SQL</ExpressionType>
<InputSerialization>
<CompressionType>NONE</CompressionType>
<CSV>
<FileHeaderInfo>USE</FileHeaderInfo>
</CSV>
</InputSerialization>
<OutputSerialization>
<CSV>
</CSV>
</OutputSerialization>
<RequestProgress>
<Enabled>FALSE</Enabled>
</RequestProgress>
</SelectObjectContentRequest>
`)
s3Select, err := NewS3Select(bytes.NewReader(requestXML))
if err != nil {
b.Fatal(err)
}
csvData := genSampleCSVData(count)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if err = s3Select.Open(func(offset, length int64) (io.ReadCloser, error) {
return ioutil.NopCloser(bytes.NewReader(csvData)), nil
}); err != nil {
b.Fatal(err)
}
s3Select.Evaluate(&nullResponseWriter{})
s3Select.Close()
}
}
func benchmarkSelectAll(b *testing.B, count int) {
benchmarkSelect(b, count, "select * from S3Object")
}
// BenchmarkSelectAll_100K - benchmark * function with 100k records.
func BenchmarkSelectAll_100K(b *testing.B) {
benchmarkSelectAll(b, 100*humanize.KiByte)
}
// BenchmarkSelectAll_1M - benchmark * function with 1m records.
func BenchmarkSelectAll_1M(b *testing.B) {
benchmarkSelectAll(b, 1*humanize.MiByte)
}
// BenchmarkSelectAll_2M - benchmark * function with 2m records.
func BenchmarkSelectAll_2M(b *testing.B) {
benchmarkSelectAll(b, 2*humanize.MiByte)
}
// BenchmarkSelectAll_10M - benchmark * function with 10m records.
func BenchmarkSelectAll_10M(b *testing.B) {
benchmarkSelectAll(b, 10*humanize.MiByte)
}
func benchmarkAggregateCount(b *testing.B, count int) {
benchmarkSelect(b, count, "select count(*) from S3Object")
}
// BenchmarkAggregateCount_100K - benchmark count(*) function with 100k records.
func BenchmarkAggregateCount_100K(b *testing.B) {
benchmarkAggregateCount(b, 100*humanize.KiByte)
}
// BenchmarkAggregateCount_1M - benchmark count(*) function with 1m records.
func BenchmarkAggregateCount_1M(b *testing.B) {
benchmarkAggregateCount(b, 1*humanize.MiByte)
}
// BenchmarkAggregateCount_2M - benchmark count(*) function with 2m records.
func BenchmarkAggregateCount_2M(b *testing.B) {
benchmarkAggregateCount(b, 2*humanize.MiByte)
}
// BenchmarkAggregateCount_10M - benchmark count(*) function with 10m records.
func BenchmarkAggregateCount_10M(b *testing.B) {
benchmarkAggregateCount(b, 10*humanize.MiByte)
}
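These benchmarks exercise the full Open/Evaluate/Close path over in-memory CSV data. Assuming a standard GOPATH checkout of the repository, they can be run on their own with the stock Go tooling:

go test -run NONE -bench=. -benchmem github.com/minio/minio/pkg/s3select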

View File

@ -1,5 +1,5 @@
/*
 * Minio Cloud Storage, (C) 2019 Minio, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@ -18,594 +18,200 @@ package s3select
import (
	"bytes"
	"go/build"
	"io"
	"io/ioutil"
	"net/http"
	"os"
	"path"
	"reflect"
	"testing"
)

type testResponseWriter struct {
	statusCode int
	response   []byte
}

func (w *testResponseWriter) Header() http.Header {
	return nil
}

func (w *testResponseWriter) Write(p []byte) (int, error) {
	w.response = append(w.response, p...)
	return len(p), nil
}

func (w *testResponseWriter) WriteHeader(statusCode int) {
	w.statusCode = statusCode
}

func (w *testResponseWriter) Flush() {
}

func TestCSVInput(t *testing.T) {
	var requestXML = []byte(`
<?xml version="1.0" encoding="UTF-8"?>
<SelectObjectContentRequest>
    <Expression>SELECT one, two, three from S3Object</Expression>
    <ExpressionType>SQL</ExpressionType>
    <InputSerialization>
        <CompressionType>NONE</CompressionType>
        <CSV>
            <FileHeaderInfo>USE</FileHeaderInfo>
        </CSV>
    </InputSerialization>
    <OutputSerialization>
        <CSV>
        </CSV>
    </OutputSerialization>
    <RequestProgress>
        <Enabled>FALSE</Enabled>
    </RequestProgress>
</SelectObjectContentRequest>
`)

	var csvData = []byte(`one,two,three
10,true,"foo"
-3,false,"bar baz"
`)

	var expectedResult = []byte{
		0, 0, 0, 113, 0, 0, 0, 85, 186, 145, 179, 109, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 13, 58, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 24, 97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 47, 111, 99, 116, 101, 116, 45, 115, 116, 114, 101, 97, 109, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 7, 82, 101, 99, 111, 114, 100, 115, 49, 48, 44, 116, 114, 117, 101, 44, 102, 111, 111, 10, 225, 160, 249, 157, 0, 0, 0, 118, 0, 0, 0, 85, 8, 177, 111, 125, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 13, 58, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 24, 97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 47, 111, 99, 116, 101, 116, 45, 115, 116, 114, 101, 97, 109, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 7, 82, 101, 99, 111, 114, 100, 115, 45, 51, 44, 102, 97, 108, 115, 101, 44, 98, 97, 114, 32, 98, 97, 122, 10, 120, 72, 77, 126, 0, 0, 0, 235, 0, 0, 0, 67, 213, 243, 57, 141, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 13, 58, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 8, 116, 101, 120, 116, 47, 120, 109, 108, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 5, 83, 116, 97, 116, 115, 60, 63, 120, 109, 108, 32, 118, 101, 114, 115, 105, 111, 110, 61, 34, 49, 46, 48, 34, 32, 101, 110, 99, 111, 100, 105, 110, 103, 61, 34, 85, 84, 70, 45, 56, 34, 63, 62, 60, 83, 116, 97, 116, 115, 62, 60, 66, 121, 116, 101, 115, 83, 99, 97, 110, 110, 101, 100, 62, 52, 55, 60, 47, 66, 121, 116, 101, 115, 83, 99, 97, 110, 110, 101, 100, 62, 60, 66, 121, 116, 101, 115, 80, 114, 111, 99, 101, 115, 115, 101, 100, 62, 52, 55, 60, 47, 66, 121, 116, 101, 115, 80, 114, 111, 99, 101, 115, 115, 101, 100, 62, 60, 66, 121, 116, 101, 115, 82, 101, 116, 117, 114, 110, 101, 100, 62, 50, 57, 60, 47, 66, 121, 116, 101, 115, 82, 101, 116, 117, 114, 110, 101, 100, 62, 60, 47, 83, 116, 97, 116, 115, 62, 214, 225, 163, 199, 0, 0, 0, 56, 0, 0, 0, 40, 193, 198, 132, 212, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 3, 69, 110, 100, 207, 151, 211, 146,
	}

	s3Select, err := NewS3Select(bytes.NewReader(requestXML))
	if err != nil {
		t.Fatal(err)
	}

	if err = s3Select.Open(func(offset, length int64) (io.ReadCloser, error) {
		return ioutil.NopCloser(bytes.NewReader(csvData)), nil
	}); err != nil {
		t.Fatal(err)
	}

	w := &testResponseWriter{}
	s3Select.Evaluate(w)
	s3Select.Close()

	if !reflect.DeepEqual(w.response, expectedResult) {
		t.Fatalf("received response does not match with expected reply")
	}
}
func TestJSONInput(t *testing.T) {
var requestXML = []byte(`
<?xml version="1.0" encoding="UTF-8"?>
<SelectObjectContentRequest>
<Expression>SELECT one, two, three from S3Object</Expression>
<ExpressionType>SQL</ExpressionType>
<InputSerialization>
<CompressionType>NONE</CompressionType>
<JSON>
<Type>DOCUMENT</Type>
</JSON>
</InputSerialization>
<OutputSerialization>
<CSV>
</CSV>
</OutputSerialization>
<RequestProgress>
<Enabled>FALSE</Enabled>
</RequestProgress>
</SelectObjectContentRequest>
`)
var jsonData = []byte(`{"one":10,"two":true,"three":"foo"}
{"one":-3,"two":true,"three":"bar baz"}
`)
var expectedResult = []byte{
0, 0, 0, 113, 0, 0, 0, 85, 186, 145, 179, 109, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 13, 58, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 24, 97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 47, 111, 99, 116, 101, 116, 45, 115, 116, 114, 101, 97, 109, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 7, 82, 101, 99, 111, 114, 100, 115, 49, 48, 44, 116, 114, 117, 101, 44, 102, 111, 111, 10, 225, 160, 249, 157, 0, 0, 0, 117, 0, 0, 0, 85, 79, 17, 21, 173, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 13, 58, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 24, 97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 47, 111, 99, 116, 101, 116, 45, 115, 116, 114, 101, 97, 109, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 7, 82, 101, 99, 111, 114, 100, 115, 45, 51, 44, 116, 114, 117, 101, 44, 98, 97, 114, 32, 98, 97, 122, 10, 34, 12, 125, 218, 0, 0, 0, 235, 0, 0, 0, 67, 213, 243, 57, 141, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 13, 58, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 8, 116, 101, 120, 116, 47, 120, 109, 108, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 5, 83, 116, 97, 116, 115, 60, 63, 120, 109, 108, 32, 118, 101, 114, 115, 105, 111, 110, 61, 34, 49, 46, 48, 34, 32, 101, 110, 99, 111, 100, 105, 110, 103, 61, 34, 85, 84, 70, 45, 56, 34, 63, 62, 60, 83, 116, 97, 116, 115, 62, 60, 66, 121, 116, 101, 115, 83, 99, 97, 110, 110, 101, 100, 62, 55, 54, 60, 47, 66, 121, 116, 101, 115, 83, 99, 97, 110, 110, 101, 100, 62, 60, 66, 121, 116, 101, 115, 80, 114, 111, 99, 101, 115, 115, 101, 100, 62, 55, 54, 60, 47, 66, 121, 116, 101, 115, 80, 114, 111, 99, 101, 115, 115, 101, 100, 62, 60, 66, 121, 116, 101, 115, 82, 101, 116, 117, 114, 110, 101, 100, 62, 50, 56, 60, 47, 66, 121, 116, 101, 115, 82, 101, 116, 117, 114, 110, 101, 100, 62, 60, 47, 83, 116, 97, 116, 115, 62, 124, 107, 174, 242, 0, 0, 0, 56, 0, 0, 0, 40, 193, 198, 132, 212, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 3, 69, 110, 100, 207, 151, 211, 146,
}
s3Select, err := NewS3Select(bytes.NewReader(requestXML))
if err != nil {
t.Fatal(err)
}
if err = s3Select.Open(func(offset, length int64) (io.ReadCloser, error) {
return ioutil.NopCloser(bytes.NewReader(jsonData)), nil
}); err != nil {
t.Fatal(err)
}
w := &testResponseWriter{}
s3Select.Evaluate(w)
s3Select.Close()
if !reflect.DeepEqual(w.response, expectedResult) {
t.Fatalf("received response does not match with expected reply")
}
}
func TestParquetInput(t *testing.T) {
var requestXML = []byte(`
<?xml version="1.0" encoding="UTF-8"?>
<SelectObjectContentRequest>
<Expression>SELECT one, two, three from S3Object</Expression>
<ExpressionType>SQL</ExpressionType>
<InputSerialization>
<CompressionType>NONE</CompressionType>
<Parquet>
</Parquet>
</InputSerialization>
<OutputSerialization>
<CSV>
</CSV>
</OutputSerialization>
<RequestProgress>
<Enabled>FALSE</Enabled>
</RequestProgress>
</SelectObjectContentRequest>
`)
getReader := func(offset int64, length int64) (io.ReadCloser, error) {
testdataFile := path.Join(build.Default.GOPATH, "src/github.com/minio/minio/pkg/s3select/testdata.parquet")
file, err := os.Open(testdataFile)
if err != nil {
return nil, err
} }
fi, err := file.Stat()
if err != nil {
return nil, err
}
if offset < 0 {
offset = fi.Size() + offset
}
if _, err = file.Seek(offset, os.SEEK_SET); err != nil {
return nil, err
}
return file, nil
}
var expectedResult = []byte{
0, 0, 0, 114, 0, 0, 0, 85, 253, 49, 201, 189, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 13, 58, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 24, 97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 47, 111, 99, 116, 101, 116, 45, 115, 116, 114, 101, 97, 109, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 7, 82, 101, 99, 111, 114, 100, 115, 50, 46, 53, 44, 102, 111, 111, 44, 116, 114, 117, 101, 10, 209, 8, 249, 77, 0, 0, 0, 114, 0, 0, 0, 85, 253, 49, 201, 189, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 13, 58, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 24, 97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 47, 111, 99, 116, 101, 116, 45, 115, 116, 114, 101, 97, 109, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 7, 82, 101, 99, 111, 114, 100, 115, 45, 49, 44, 98, 97, 114, 44, 102, 97, 108, 115, 101, 10, 45, 143, 126, 67, 0, 0, 0, 113, 0, 0, 0, 85, 186, 145, 179, 109, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 13, 58, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 24, 97, 112, 112, 108, 105, 99, 97, 116, 105, 111, 110, 47, 111, 99, 116, 101, 116, 45, 115, 116, 114, 101, 97, 109, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 7, 82, 101, 99, 111, 114, 100, 115, 45, 49, 44, 98, 97, 122, 44, 116, 114, 117, 101, 10, 230, 139, 42, 176, 0, 0, 0, 235, 0, 0, 0, 67, 213, 243, 57, 141, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 13, 58, 99, 111, 110, 116, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 8, 116, 101, 120, 116, 47, 120, 109, 108, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 5, 83, 116, 97, 116, 115, 60, 63, 120, 109, 108, 32, 118, 101, 114, 115, 105, 111, 110, 61, 34, 49, 46, 48, 34, 32, 101, 110, 99, 111, 100, 105, 110, 103, 61, 34, 85, 84, 70, 45, 56, 34, 63, 62, 60, 83, 116, 97, 116, 115, 62, 60, 66, 121, 116, 101, 115, 83, 99, 97, 110, 110, 101, 100, 62, 45, 49, 60, 47, 66, 121, 116, 101, 115, 83, 99, 97, 110, 110, 101, 100, 62, 60, 66, 121, 116, 101, 115, 80, 114, 111, 99, 101, 115, 115, 101, 100, 62, 45, 49, 60, 47, 66, 121, 116, 101, 115, 80, 114, 111, 99, 101, 115, 115, 101, 100, 62, 60, 66, 121, 116, 101, 115, 82, 101, 116, 117, 114, 110, 101, 100, 62, 51, 56, 60, 47, 66, 121, 116, 101, 115, 82, 101, 116, 117, 114, 110, 101, 100, 62, 60, 47, 83, 116, 97, 116, 115, 62, 199, 176, 2, 83, 0, 0, 0, 56, 0, 0, 0, 40, 193, 198, 132, 212, 13, 58, 109, 101, 115, 115, 97, 103, 101, 45, 116, 121, 112, 101, 7, 0, 5, 101, 118, 101, 110, 116, 11, 58, 101, 118, 101, 110, 116, 45, 116, 121, 112, 101, 7, 0, 3, 69, 110, 100, 207, 151, 211, 146,
}
s3Select, err := NewS3Select(bytes.NewReader(requestXML))
if err != nil {
t.Fatal(err)
}
if err = s3Select.Open(getReader); err != nil {
t.Fatal(err)
}
w := &testResponseWriter{}
s3Select.Evaluate(w)
s3Select.Close()
if !reflect.DeepEqual(w.response, expectedResult) {
t.Fatalf("received response does not match with expected reply")
}
}
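The expectedResult fixtures in these tests are raw AWS event-stream messages, the framing the message writer emits for Records, Stats and End events: each message opens with a 12-byte prelude (big-endian total byte-length, headers byte-length, and a CRC32 of those eight bytes), followed by headers, payload and a trailing message CRC. A small sketch decoding the prelude of the first CSV fixture frame; this helper is illustrative, not part of the commit:

package main

import (
	"encoding/binary"
	"fmt"
)

// decodePrelude reads the 12-byte prelude of one event-stream message:
// total byte-length, headers byte-length, and the CRC32 of those 8 bytes.
func decodePrelude(msg []byte) (total, headersLen, preludeCRC uint32, err error) {
	if len(msg) < 12 {
		return 0, 0, 0, fmt.Errorf("short message: %d bytes", len(msg))
	}
	total = binary.BigEndian.Uint32(msg[0:4])
	headersLen = binary.BigEndian.Uint32(msg[4:8])
	preludeCRC = binary.BigEndian.Uint32(msg[8:12])
	return total, headersLen, preludeCRC, nil
}

func main() {
	// Prelude of the first Records frame from the CSV test fixture above.
	frame := []byte{0, 0, 0, 113, 0, 0, 0, 85, 186, 145, 179, 109}
	total, hdrs, crc, _ := decodePrelude(frame)
	fmt.Println(total, hdrs, crc) // 113 total bytes, 85 header bytes
}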

View File

@ -0,0 +1,175 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import "fmt"
// ArithOperator - arithmetic operator.
type ArithOperator string
const (
// Add operator '+'.
Add ArithOperator = "+"
// Subtract operator '-'.
Subtract ArithOperator = "-"
// Multiply operator '*'.
Multiply ArithOperator = "*"
// Divide operator '/'.
Divide ArithOperator = "/"
// Modulo operator '%'.
Modulo ArithOperator = "%"
)
// arithExpr - arithmetic function.
type arithExpr struct {
left Expr
right Expr
operator ArithOperator
funcType Type
}
// String - returns string representation of this function.
func (f *arithExpr) String() string {
return fmt.Sprintf("(%v %v %v)", f.left, f.operator, f.right)
}
func (f *arithExpr) compute(lv, rv *Value) (*Value, error) {
leftValueType := lv.Type()
rightValueType := rv.Type()
if !leftValueType.isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValueType)
return nil, errExternalEvalException(err)
}
if !rightValueType.isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValueType)
return nil, errExternalEvalException(err)
}
leftValue := lv.FloatValue()
rightValue := rv.FloatValue()
var result float64
switch f.operator {
case Add:
result = leftValue + rightValue
case Subtract:
result = leftValue - rightValue
case Multiply:
result = leftValue * rightValue
case Divide:
result = leftValue / rightValue
case Modulo:
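// Note: both operands are truncated to int64 here, and a zero right-hand
// side panics at runtime with "integer divide by zero" (unlike Divide
// above, where float division by zero yields ±Inf).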
result = float64(int64(leftValue) % int64(rightValue))
}
if leftValueType == Float || rightValueType == Float {
return NewFloat(result), nil
}
return NewInt(int64(result)), nil
}
// Eval - evaluates this expression for given record and returns result as Value. For aggregate
// expressions it returns nil; the aggregated result is read later via AggregateValue().
func (f *arithExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
return nil, nil
}
return f.compute(leftValue, rightValue)
}
// AggregateValue - returns aggregated value.
func (f *arithExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not aggreate expression", f)
return nil, errExternalEvalException(err)
}
lv, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
rv, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
return f.compute(lv, rv)
}
// Type - returns arithmeticFunction or aggregateFunction type.
func (f *arithExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Float as return type.
func (f *arithExpr) ReturnType() Type {
return Float
}
// newArithExpr - creates new arithmetic function.
func newArithExpr(operator ArithOperator, left, right Expr) (*arithExpr, error) {
if !left.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluate to %v, not number", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
if !right.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluate to %v; not number", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := arithmeticFunction
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &arithExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
}
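The net effect of compute's final branch is a simple promotion rule: the arithmetic always runs in float64, and the result is narrowed back to an Int only when both operands were Int. A standalone sketch of that rule with plain Go values (not the package's Value type), just to make the truncation visible:

package main

import "fmt"

// promote mirrors arithExpr.compute's tail: keep the float64 result when
// either side was a Float, otherwise truncate it back to int64.
func promote(result float64, leftIsFloat, rightIsFloat bool) interface{} {
	if leftIsFloat || rightIsFloat {
		return result
	}
	return int64(result)
}

func main() {
	fmt.Println(promote(7.0/2.0, false, false)) // 3: Int / Int truncates at the end
	fmt.Println(promote(7.0/2.0, true, false))  // 3.5: a Float operand keeps the fraction
}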

View File

@ -0,0 +1,636 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"regexp"
"strings"
)
// ComparisonOperator - comparison operator.
type ComparisonOperator string
const (
// Equal operator '='.
Equal ComparisonOperator = "="
// NotEqual operator '!=' or '<>'.
NotEqual ComparisonOperator = "!="
// LessThan operator '<'.
LessThan ComparisonOperator = "<"
// GreaterThan operator '>'.
GreaterThan ComparisonOperator = ">"
// LessThanEqual operator '<='.
LessThanEqual ComparisonOperator = "<="
// GreaterThanEqual operator '>='.
GreaterThanEqual ComparisonOperator = ">="
// Between operator 'BETWEEN'
Between ComparisonOperator = "between"
// In operator 'IN'
In ComparisonOperator = "in"
// Like operator 'LIKE'
Like ComparisonOperator = "like"
// NotBetween operator 'NOT BETWEEN'
NotBetween ComparisonOperator = "not between"
// NotIn operator 'NOT IN'
NotIn ComparisonOperator = "not in"
// NotLike operator 'NOT LIKE'
NotLike ComparisonOperator = "not like"
// IsNull operator 'IS NULL'
IsNull ComparisonOperator = "is null"
// IsNotNull operator 'IS NOT NULL'
IsNotNull ComparisonOperator = "is not null"
)
// String - returns string representation of this operator.
func (operator ComparisonOperator) String() string {
return strings.ToUpper(string(operator))
}
func equal(leftValue, rightValue *Value) (bool, error) {
switch {
case leftValue.Type() == Null && rightValue.Type() == Null:
return true, nil
case leftValue.Type() == Bool && rightValue.Type() == Bool:
return leftValue.BoolValue() == rightValue.BoolValue(), nil
case (leftValue.Type() == Int || leftValue.Type() == Float) &&
(rightValue.Type() == Int || rightValue.Type() == Float):
return leftValue.FloatValue() == rightValue.FloatValue(), nil
case leftValue.Type() == String && rightValue.Type() == String:
return leftValue.StringValue() == rightValue.StringValue(), nil
case leftValue.Type() == Timestamp && rightValue.Type() == Timestamp:
return leftValue.TimeValue() == rightValue.TimeValue(), nil
}
return false, fmt.Errorf("left value type %v and right value type %v are incompatible for equality check", leftValue.Type(), rightValue.Type())
}
// comparisonExpr - comparison function.
type comparisonExpr struct {
left Expr
right Expr
to Expr
operator ComparisonOperator
funcType Type
}
// String - returns string representation of this function.
func (f *comparisonExpr) String() string {
switch f.operator {
case Equal, NotEqual, LessThan, GreaterThan, LessThanEqual, GreaterThanEqual, In, Like, NotIn, NotLike:
return fmt.Sprintf("(%v %v %v)", f.left, f.operator, f.right)
case Between, NotBetween:
return fmt.Sprintf("(%v %v %v AND %v)", f.left, f.operator, f.right, f.to)
}
return fmt.Sprintf("(%v %v %v %v)", f.left, f.right, f.to, f.operator)
}
func (f *comparisonExpr) equal(leftValue, rightValue *Value) (*Value, error) {
result, err := equal(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(result), nil
}
func (f *comparisonExpr) notEqual(leftValue, rightValue *Value) (*Value, error) {
result, err := equal(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(!result), nil
}
func (f *comparisonExpr) lessThan(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() < rightValue.FloatValue()), nil
}
func (f *comparisonExpr) greaterThan(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() > rightValue.FloatValue()), nil
}
func (f *comparisonExpr) lessThanEqual(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() <= rightValue.FloatValue()), nil
}
func (f *comparisonExpr) greaterThanEqual(leftValue, rightValue *Value) (*Value, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !rightValue.Type().isNumber() {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to number", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(leftValue.FloatValue() >= rightValue.FloatValue()), nil
}
func (f *comparisonExpr) computeBetween(leftValue, fromValue, toValue *Value) (bool, error) {
if !leftValue.Type().isNumber() {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to number", f, leftValue.Type())
return false, errExternalEvalException(err)
}
if !fromValue.Type().isNumber() {
err := fmt.Errorf("%v: from side expression evaluated to %v; not to number", f, fromValue.Type())
return false, errExternalEvalException(err)
}
if !toValue.Type().isNumber() {
err := fmt.Errorf("%v: to side expression evaluated to %v; not to number", f, toValue.Type())
return false, errExternalEvalException(err)
}
return leftValue.FloatValue() >= fromValue.FloatValue() &&
leftValue.FloatValue() <= toValue.FloatValue(), nil
}
func (f *comparisonExpr) between(leftValue, fromValue, toValue *Value) (*Value, error) {
result, err := f.computeBetween(leftValue, fromValue, toValue)
if err != nil {
return nil, err
}
return NewBool(result), nil
}
func (f *comparisonExpr) notBetween(leftValue, fromValue, toValue *Value) (*Value, error) {
result, err := f.computeBetween(leftValue, fromValue, toValue)
if err != nil {
return nil, err
}
return NewBool(!result), nil
}
func (f *comparisonExpr) computeIn(leftValue, rightValue *Value) (found bool, err error) {
if rightValue.Type() != Array {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to Array", f, rightValue.Type())
return false, errExternalEvalException(err)
}
values := rightValue.ArrayValue()
for i := range values {
found, err = equal(leftValue, values[i])
if err != nil {
return false, err
}
if found {
return true, nil
}
}
return false, nil
}
func (f *comparisonExpr) in(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeIn(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(result), nil
}
func (f *comparisonExpr) notIn(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeIn(leftValue, rightValue)
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return nil, errExternalEvalException(err)
}
return NewBool(!result), nil
}
func (f *comparisonExpr) computeLike(leftValue, rightValue *Value) (matched bool, err error) {
if leftValue.Type() != String {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to string", f, leftValue.Type())
return false, errExternalEvalException(err)
}
if rightValue.Type() != String {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to string", f, rightValue.Type())
return false, errExternalEvalException(err)
}
// Note: the pattern is compiled as a Go regular expression; SQL LIKE
// wildcards (% and _) are not translated before matching.
matched, err = regexp.MatchString(rightValue.StringValue(), leftValue.StringValue())
if err != nil {
err = fmt.Errorf("%v: %v", f, err)
return false, errExternalEvalException(err)
}
return matched, nil
}
func (f *comparisonExpr) like(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeLike(leftValue, rightValue)
if err != nil {
return nil, err
}
return NewBool(result), nil
}
func (f *comparisonExpr) notLike(leftValue, rightValue *Value) (*Value, error) {
result, err := f.computeLike(leftValue, rightValue)
if err != nil {
return nil, err
}
return NewBool(!result), nil
}
func (f *comparisonExpr) compute(leftValue, rightValue, toValue *Value) (*Value, error) {
switch f.operator {
case Equal:
return f.equal(leftValue, rightValue)
case NotEqual:
return f.notEqual(leftValue, rightValue)
case LessThan:
return f.lessThan(leftValue, rightValue)
case GreaterThan:
return f.greaterThan(leftValue, rightValue)
case LessThanEqual:
return f.lessThanEqual(leftValue, rightValue)
case GreaterThanEqual:
return f.greaterThanEqual(leftValue, rightValue)
case Between:
return f.between(leftValue, rightValue, toValue)
case In:
return f.in(leftValue, rightValue)
case Like:
return f.like(leftValue, rightValue)
case NotBetween:
return f.notBetween(leftValue, rightValue, toValue)
case NotIn:
return f.notIn(leftValue, rightValue)
case NotLike:
return f.notLike(leftValue, rightValue)
}
panic(fmt.Errorf("unexpected expression %v", f))
}
// Eval - evaluates this expression for the given record and returns the result as a Value.
func (f *comparisonExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
var toValue *Value
if f.to != nil {
toValue, err = f.to.Eval(record)
if err != nil {
return nil, err
}
}
if f.funcType == aggregateFunction {
return nil, nil
}
return f.compute(leftValue, rightValue, toValue)
}
// AggregateValue - returns aggregated value.
func (f *comparisonExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not an aggregate expression", f)
return nil, errExternalEvalException(err)
}
leftValue, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
var toValue *Value
if f.to != nil {
toValue, err = f.to.AggregateValue()
if err != nil {
return nil, err
}
}
return f.compute(leftValue, rightValue, toValue)
}
// Type - returns comparisonFunction or aggregateFunction type.
func (f *comparisonExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *comparisonExpr) ReturnType() Type {
return Bool
}
// newComparisonExpr - creates new comparison function.
func newComparisonExpr(operator ComparisonOperator, funcs ...Expr) (*comparisonExpr, error) {
funcType := comparisonFunction
switch operator {
case Equal, NotEqual:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isBaseKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluates to %v, which is incompatible for an equality check", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
right := funcs[1]
if !right.ReturnType().isBaseKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluates to %v, which is incompatible for an equality check", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for equality check", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for equality check", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case LessThan, GreaterThan, LessThanEqual, GreaterThanEqual:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluates to %v, not a number", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
right := funcs[1]
if !right.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluates to %v, not a number", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case In, NotIn:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isBaseKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluates to %v, which is incompatible for an equality check", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
right := funcs[1]
if right.ReturnType() != Array {
err := fmt.Errorf("operator %v: right side expression %v evaluates to %v, which is incompatible for an equality check", operator, right, right.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case column, Array, function, arithmeticFunction, comparisonFunction, logicalFunction, record:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case Array, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case Like, NotLike:
if len(funcs) != 2 {
panic(fmt.Sprintf("exactly two arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isStringKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluates to %v, not a string", operator, left, left.ReturnType())
return nil, errLikeInvalidInputs(err)
}
right := funcs[1]
if !right.ReturnType().isStringKind() {
err := fmt.Errorf("operator %v: right side expression %v evaluates to %v, not a string", operator, right, right.ReturnType())
return nil, errLikeInvalidInputs(err)
}
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case String, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch right.Type() {
case String, aggregateFunction:
default:
err := fmt.Errorf("operator %v: right side expression %v return type %v is incompatible for aggregate evaluation", operator, right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: right,
operator: operator,
funcType: funcType,
}, nil
case Between, NotBetween:
if len(funcs) != 3 {
panic(fmt.Sprintf("exactly three arguments are expected, but found %v", len(funcs)))
}
left := funcs[0]
if !left.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: left side expression %v evaluates to %v, not a number", operator, left, left.ReturnType())
return nil, errInvalidDataType(err)
}
from := funcs[1]
if !from.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: from expression %v evaluates to %v, not a number", operator, from, from.ReturnType())
return nil, errInvalidDataType(err)
}
to := funcs[2]
if !to.ReturnType().isNumberKind() {
err := fmt.Errorf("operator %v: to expression %v evaluates to %v, not a number", operator, to, to.ReturnType())
return nil, errInvalidDataType(err)
}
if left.Type() == aggregateFunction || from.Type() == aggregateFunction || to.Type() == aggregateFunction {
funcType = aggregateFunction
switch left.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: left side expression %v return type %v is incompatible for aggregate evaluation", operator, left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch from.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: from expression %v return type %v is incompatible for aggregate evaluation", operator, from, from.Type())
return nil, errUnsupportedSQLOperation(err)
}
switch to.Type() {
case Int, Float, aggregateFunction:
default:
err := fmt.Errorf("operator %v: to expression %v return type %v is incompatible for aggregate evaluation", operator, to, to.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &comparisonExpr{
left: left,
right: from,
to: to,
operator: operator,
funcType: funcType,
}, nil
case IsNull, IsNotNull:
if len(funcs) != 1 {
panic(fmt.Sprintf("exactly one argument is expected, but found %v", len(funcs)))
}
if funcs[0].Type() == aggregateFunction {
funcType = aggregateFunction
}
if operator == IsNull {
operator = Equal
} else {
operator = NotEqual
}
return &comparisonExpr{
left: funcs[0],
right: newValueExpr(NewNull()),
operator: operator,
funcType: funcType,
}, nil
}
return nil, errParseUnknownOperator(fmt.Errorf("unknown operator %v", operator))
}
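To make the constructor contract above concrete, here is a hypothetical sketch (not part of this commit) of folding a constant comparison; Eval(nil) is safe because no operand needs a record:

// exampleConstantFold is an illustrative sketch: build and evaluate the
// constant comparison 3 BETWEEN 1 AND 5.
func exampleConstantFold() (*Value, error) {
	expr, err := newComparisonExpr(Between,
		newValueExpr(NewInt(3)), newValueExpr(NewInt(1)), newValueExpr(NewInt(5)))
	if err != nil {
		return nil, err
	}
	return expr.Eval(nil) // -> BOOL true
}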

pkg/s3select/sql/errors.go Normal file
@ -0,0 +1,215 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
type s3Error struct {
code string
message string
statusCode int
cause error
}
func (err *s3Error) Cause() error {
return err.cause
}
func (err *s3Error) ErrorCode() string {
return err.code
}
func (err *s3Error) ErrorMessage() string {
return err.message
}
func (err *s3Error) HTTPStatusCode() int {
return err.statusCode
}
func (err *s3Error) Error() string {
return err.message
}
func errUnsupportedSQLStructure(err error) *s3Error {
return &s3Error{
code: "UnsupportedSqlStructure",
message: "Encountered an unsupported SQL structure. Check the SQL Reference.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedSelect(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedSelect",
message: "The SQL expression contains an unsupported use of SELECT.",
statusCode: 400,
cause: err,
}
}
func errParseAsteriskIsNotAloneInSelectList(err error) *s3Error {
return &s3Error{
code: "ParseAsteriskIsNotAloneInSelectList",
message: "Other expressions are not allowed in the SELECT list when '*' is used without dot notation in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseInvalidContextForWildcardInSelectList(err error) *s3Error {
return &s3Error{
code: "ParseInvalidContextForWildcardInSelectList",
message: "Invalid use of * in SELECT list in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errInvalidDataType(err error) *s3Error {
return &s3Error{
code: "InvalidDataType",
message: "The SQL expression contains an invalid data type.",
statusCode: 400,
cause: err,
}
}
func errUnsupportedFunction(err error) *s3Error {
return &s3Error{
code: "UnsupportedFunction",
message: "Encountered an unsupported SQL function.",
statusCode: 400,
cause: err,
}
}
func errParseNonUnaryAgregateFunctionCall(err error) *s3Error {
return &s3Error{
code: "ParseNonUnaryAgregateFunctionCall",
message: "Only one argument is supported for aggregate functions in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errIncorrectSQLFunctionArgumentType(err error) *s3Error {
return &s3Error{
code: "IncorrectSqlFunctionArgumentType",
message: "Incorrect type of arguments in function call in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorInvalidArguments(err error) *s3Error {
return &s3Error{
code: "EvaluatorInvalidArguments",
message: "Incorrect number of arguments in the function call in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errUnsupportedSQLOperation(err error) *s3Error {
return &s3Error{
code: "UnsupportedSqlOperation",
message: "Encountered an unsupported SQL operation.",
statusCode: 400,
cause: err,
}
}
func errParseUnknownOperator(err error) *s3Error {
return &s3Error{
code: "ParseUnknownOperator",
message: "The SQL expression contains an invalid operator.",
statusCode: 400,
cause: err,
}
}
func errLikeInvalidInputs(err error) *s3Error {
return &s3Error{
code: "LikeInvalidInputs",
message: "Invalid argument given to the LIKE clause in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errExternalEvalException(err error) *s3Error {
return &s3Error{
code: "ExternalEvalException",
message: "The query cannot be evaluated. Check the file and try again.",
statusCode: 400,
cause: err,
}
}
func errValueParseFailure(err error) *s3Error {
return &s3Error{
code: "ValueParseFailure",
message: "Time stamp parse failure in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorBindingDoesNotExist(err error) *s3Error {
return &s3Error{
code: "EvaluatorBindingDoesNotExist",
message: "A column name or a path provided does not exist in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errInternalError(err error) *s3Error {
return &s3Error{
code: "InternalError",
message: "Encountered an internal error.",
statusCode: 500,
cause: err,
}
}
func errParseInvalidTypeParam(err error) *s3Error {
return &s3Error{
code: "ParseInvalidTypeParam",
message: "The SQL expression contains an invalid parameter value.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedSyntax(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedSyntax",
message: "The SQL expression contains unsupported syntax.",
statusCode: 400,
cause: err,
}
}
func errInvalidKeyPath(err error) *s3Error {
return &s3Error{
code: "InvalidKeyPath",
message: "Key path in the SQL expression is invalid.",
statusCode: 400,
cause: err,
}
}
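Every constructor in this file yields the same *s3Error shape, so callers can surface the code, message, and status uniformly. A hedged sketch of such a caller (the handler wiring and the net/http and log imports are assumptions, not part of this commit):

// writeSQLError is an illustrative sketch of surfacing an *s3Error.
func writeSQLError(w http.ResponseWriter, err error) {
	if serr, ok := err.(*s3Error); ok {
		// code/message/status come from the tables above; cause is for logs.
		log.Printf("s3select: %s: %v", serr.ErrorCode(), serr.Cause())
		http.Error(w, serr.ErrorMessage(), serr.HTTPStatusCode())
		return
	}
	http.Error(w, err.Error(), http.StatusInternalServerError)
}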

pkg/s3select/sql/expr.go Normal file
@ -0,0 +1,160 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
)
// Expr - a SQL expression type.
type Expr interface {
AggregateValue() (*Value, error)
Eval(record Record) (*Value, error)
ReturnType() Type
Type() Type
}
// aliasExpr - aliases expression by alias.
type aliasExpr struct {
alias string
expr Expr
}
// String - returns string representation of this expression.
func (expr *aliasExpr) String() string {
return fmt.Sprintf("(%v AS %v)", expr.expr, expr.alias)
}
// Eval - evaluates the underlying expression for the given record and returns the result.
func (expr *aliasExpr) Eval(record Record) (*Value, error) {
return expr.expr.Eval(record)
}
// AggregateValue - returns the aggregated value from the underlying expression.
func (expr *aliasExpr) AggregateValue() (*Value, error) {
return expr.expr.AggregateValue()
}
// Type - returns the underlying expression's type.
func (expr *aliasExpr) Type() Type {
return expr.expr.Type()
}
// ReturnType - returns the underlying expression's return type.
func (expr *aliasExpr) ReturnType() Type {
return expr.expr.ReturnType()
}
// newAliasExpr - creates new alias expression.
func newAliasExpr(alias string, expr Expr) *aliasExpr {
return &aliasExpr{alias, expr}
}
// starExpr - asterisk (*) expression.
type starExpr struct {
}
// String - returns string representation of this expression.
func (expr *starExpr) String() string {
return "*"
}
// Eval - returns the whole record as a single record value.
func (expr *starExpr) Eval(record Record) (*Value, error) {
return newRecordValue(record), nil
}
// AggregateValue - returns nil value.
func (expr *starExpr) AggregateValue() (*Value, error) {
return nil, nil
}
// Type - returns record type.
func (expr *starExpr) Type() Type {
return record
}
// ReturnType - returns record as return type.
func (expr *starExpr) ReturnType() Type {
return record
}
// newStarExpr - returns new asterisk (*) expression.
func newStarExpr() *starExpr {
return &starExpr{}
}
type valueExpr struct {
value *Value
}
func (expr *valueExpr) String() string {
return expr.value.String()
}
func (expr *valueExpr) Eval(record Record) (*Value, error) {
return expr.value, nil
}
func (expr *valueExpr) AggregateValue() (*Value, error) {
return expr.value, nil
}
func (expr *valueExpr) Type() Type {
return expr.value.Type()
}
func (expr *valueExpr) ReturnType() Type {
return expr.value.Type()
}
func newValueExpr(value *Value) *valueExpr {
return &valueExpr{value: value}
}
type columnExpr struct {
name string
}
func (expr *columnExpr) String() string {
return expr.name
}
func (expr *columnExpr) Eval(record Record) (*Value, error) {
value, err := record.Get(expr.name)
if err != nil {
return nil, errEvaluatorBindingDoesNotExist(err)
}
return value, nil
}
func (expr *columnExpr) AggregateValue() (*Value, error) {
return nil, nil
}
func (expr *columnExpr) Type() Type {
return column
}
func (expr *columnExpr) ReturnType() Type {
return column
}
func newColumnExpr(columnName string) *columnExpr {
return &columnExpr{name: columnName}
}
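A short hypothetical sketch of how these wrappers nest: an alias delegates evaluation to the wrapped expression and only changes the output name.

// exampleAlias is an illustrative sketch: (amount AS total) evaluates the
// wrapped column against any Record implementation.
func exampleAlias(rec Record) (*Value, error) {
	expr := newAliasExpr("total", newColumnExpr("amount"))
	// A missing "amount" column surfaces as errEvaluatorBindingDoesNotExist.
	return expr.Eval(rec)
}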

@ -0,0 +1,550 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"strings"
"time"
"unicode/utf8"
)
// FuncName - SQL function name.
type FuncName string
const (
// Avg - aggregate SQL function AVG().
Avg FuncName = "AVG"
// Count - aggregate SQL function COUNT().
Count FuncName = "COUNT"
// Max - aggregate SQL function MAX().
Max FuncName = "MAX"
// Min - aggregate SQL function MIN().
Min FuncName = "MIN"
// Sum - aggregate SQL function SUM().
Sum FuncName = "SUM"
// Coalesce - conditional SQL function COALESCE().
Coalesce FuncName = "COALESCE"
// NullIf - conditional SQL function NULLIF().
NullIf FuncName = "NULLIF"
// ToTimestamp - conversion SQL function TO_TIMESTAMP().
ToTimestamp FuncName = "TO_TIMESTAMP"
// UTCNow - date SQL function UTCNOW().
UTCNow FuncName = "UTCNOW"
// CharLength - string SQL function CHAR_LENGTH().
CharLength FuncName = "CHAR_LENGTH"
// CharacterLength - string SQL function CHARACTER_LENGTH() same as CHAR_LENGTH().
CharacterLength FuncName = "CHARACTER_LENGTH"
// Lower - string SQL function LOWER().
Lower FuncName = "LOWER"
// Substring - string SQL function SUBSTRING().
Substring FuncName = "SUBSTRING"
// Trim - string SQL function TRIM().
Trim FuncName = "TRIM"
// Upper - string SQL function UPPER().
Upper FuncName = "UPPER"
// DateAdd FuncName = "DATE_ADD"
// DateDiff FuncName = "DATE_DIFF"
// Extract FuncName = "EXTRACT"
// ToString FuncName = "TO_STRING"
// Cast FuncName = "CAST" // CAST('2007-04-05T14:30Z' AS TIMESTAMP)
)
func isAggregateFuncName(s string) bool {
switch FuncName(s) {
case Avg, Count, Max, Min, Sum:
return true
}
return false
}
func callForNumber(f Expr, record Record) (*Value, error) {
value, err := f.Eval(record)
if err != nil {
return nil, err
}
if !value.Type().isNumber() {
err := fmt.Errorf("%v evaluated to %v; not to number", f, value.Type())
return nil, errExternalEvalException(err)
}
return value, nil
}
func callForInt(f Expr, record Record) (*Value, error) {
value, err := f.Eval(record)
if err != nil {
return nil, err
}
if value.Type() != Int {
err := fmt.Errorf("%v evaluated to %v; not to int", f, value.Type())
return nil, errExternalEvalException(err)
}
return value, nil
}
func callForString(f Expr, record Record) (*Value, error) {
value, err := f.Eval(record)
if err != nil {
return nil, err
}
if value.Type() != String {
err := fmt.Errorf("%v evaluated to %v; not to string", f, value.Type())
return nil, errExternalEvalException(err)
}
return value, nil
}
// funcExpr - SQL function.
type funcExpr struct {
args []Expr
name FuncName
sumValue float64
countValue int64
maxValue float64
minValue float64
}
// String - returns string representation of this function.
func (f *funcExpr) String() string {
var argStrings []string
for _, arg := range f.args {
argStrings = append(argStrings, fmt.Sprintf("%v", arg))
}
return fmt.Sprintf("%v(%v)", f.name, strings.Join(argStrings, ","))
}
func (f *funcExpr) sum(record Record) (*Value, error) {
value, err := callForNumber(f.args[0], record)
if err != nil {
return nil, err
}
f.sumValue += value.FloatValue()
f.countValue++
return nil, nil
}
func (f *funcExpr) count(record Record) (*Value, error) {
value, err := f.args[0].Eval(record)
if err != nil {
return nil, err
}
if value.valueType != Null {
f.countValue++
}
return nil, nil
}
func (f *funcExpr) max(record Record) (*Value, error) {
value, err := callForNumber(f.args[0], record)
if err != nil {
return nil, err
}
v := value.FloatValue()
// Track the first sample via countValue; comparing against the zero-valued
// maxValue would silently drop all-negative inputs.
if f.countValue == 0 || v > f.maxValue {
f.maxValue = v
}
f.countValue++
return nil, nil
}
func (f *funcExpr) min(record Record) (*Value, error) {
value, err := callForNumber(f.args[0], record)
if err != nil {
return nil, err
}
v := value.FloatValue()
// Track the first sample via countValue; comparing against the zero-valued
// minValue would silently drop all-positive inputs.
if f.countValue == 0 || v < f.minValue {
f.minValue = v
}
f.countValue++
return nil, nil
}
func (f *funcExpr) charLength(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
// Count characters (runes) rather than bytes to match SQL CHAR_LENGTH semantics.
return NewInt(int64(utf8.RuneCountInString(value.StringValue()))), nil
}
func (f *funcExpr) trim(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
return NewString(strings.TrimSpace(value.StringValue())), nil
}
func (f *funcExpr) lower(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
return NewString(strings.ToLower(value.StringValue())), nil
}
func (f *funcExpr) upper(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
return NewString(strings.ToUpper(value.StringValue())), nil
}
func (f *funcExpr) substring(record Record) (*Value, error) {
stringValue, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
offsetValue, err := callForInt(f.args[1], record)
if err != nil {
return nil, err
}
var lengthValue *Value
if len(f.args) == 3 {
lengthValue, err = callForInt(f.args[2], record)
if err != nil {
return nil, err
}
}
value := stringValue.StringValue()
offset := int(offsetValue.FloatValue())
if offset < 0 || offset > len(value) {
offset = 0
}
length := len(value)
if lengthValue != nil {
// Note: the third argument is treated as an end position within the string
// (not a length as in standard SQL), and offset above is 0-based.
length = int(lengthValue.FloatValue())
if length < offset || length > len(value) {
// Clamp so the slice below stays in range even when the end position
// falls before the start offset.
length = len(value)
}
}
return NewString(value[offset:length]), nil
}
func (f *funcExpr) coalesce(record Record) (*Value, error) {
values := make([]*Value, len(f.args))
var err error
for i := range f.args {
values[i], err = f.args[i].Eval(record)
if err != nil {
return nil, err
}
}
for i := range values {
if values[i].Type() != Null {
return values[i], nil
}
}
return values[0], nil
}
func (f *funcExpr) nullIf(record Record) (*Value, error) {
value1, err := f.args[0].Eval(record)
if err != nil {
return nil, err
}
value2, err := f.args[1].Eval(record)
if err != nil {
return nil, err
}
result, err := equal(value1, value2)
if err != nil {
return nil, err
}
if result {
return NewNull(), nil
}
return value1, nil
}
func (f *funcExpr) toTimeStamp(record Record) (*Value, error) {
value, err := callForString(f.args[0], record)
if err != nil {
return nil, err
}
t, err := time.Parse(time.RFC3339, value.StringValue())
if err != nil {
err := fmt.Errorf("%v: value '%v': %v", f, value, err)
return nil, errValueParseFailure(err)
}
return NewTime(t), nil
}
func (f *funcExpr) utcNow(record Record) (*Value, error) {
return NewTime(time.Now().UTC()), nil
}
// Eval - evaluates this function for the given record and returns the result as a Value.
func (f *funcExpr) Eval(record Record) (*Value, error) {
switch f.name {
case Avg, Sum:
return f.sum(record)
case Count:
return f.count(record)
case Max:
return f.max(record)
case Min:
return f.min(record)
case Coalesce:
return f.coalesce(record)
case NullIf:
return f.nullIf(record)
case ToTimestamp:
return f.toTimeStamp(record)
case UTCNow:
return f.utcNow(record)
case Substring:
return f.substring(record)
case CharLength, CharacterLength:
return f.charLength(record)
case Trim:
return f.trim(record)
case Lower:
return f.lower(record)
case Upper:
return f.upper(record)
}
panic(fmt.Sprintf("unsupported function %v", f.name))
}
// AggregateValue - returns aggregated value.
func (f *funcExpr) AggregateValue() (*Value, error) {
switch f.name {
case Avg:
return NewFloat(f.sumValue / float64(f.countValue)), nil
case Count:
return NewInt(f.countValue), nil
case Max:
return NewFloat(f.maxValue), nil
case Min:
return NewFloat(f.minValue), nil
case Sum:
return NewFloat(f.sumValue), nil
}
err := fmt.Errorf("%v is not an aggregate function", f)
return nil, errExternalEvalException(err)
}
// Type - returns Function or aggregateFunction type.
func (f *funcExpr) Type() Type {
switch f.name {
case Avg, Count, Max, Min, Sum:
return aggregateFunction
}
return function
}
// ReturnType - returns respective primitive type depending on SQL function.
func (f *funcExpr) ReturnType() Type {
switch f.name {
case Avg, Max, Min, Sum:
return Float
case Count:
return Int
case CharLength, CharacterLength, Trim, Lower, Upper, Substring:
return String
case ToTimestamp, UTCNow:
return Timestamp
case Coalesce, NullIf:
return column
}
return function
}
// newFuncExpr - creates new SQL function.
func newFuncExpr(funcName FuncName, funcs ...Expr) (*funcExpr, error) {
switch funcName {
case Avg, Max, Min, Sum:
if len(funcs) != 1 {
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
return nil, errParseNonUnaryAgregateFunctionCall(err)
}
if !funcs[0].ReturnType().isNumberKind() {
err := fmt.Errorf("%v(): argument %v evaluates to %v, not a number", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case Count:
if len(funcs) != 1 {
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
return nil, errParseNonUnaryAgregateFunctionCall(err)
}
switch funcs[0].ReturnType() {
case Null, Bool, Int, Float, String, Timestamp, column, record:
default:
err := fmt.Errorf("%v(): argument %v evaluates to %v, which is incompatible", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case CharLength, CharacterLength, Trim, Lower, Upper, ToTimestamp:
if len(funcs) != 1 {
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
if !funcs[0].ReturnType().isStringKind() {
err := fmt.Errorf("%v(): argument %v evaluates to %v, not a string", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case Coalesce:
if len(funcs) < 1 {
err := fmt.Errorf("%v(): one or more arguments expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
for i := range funcs {
if !funcs[i].ReturnType().isBaseKind() {
err := fmt.Errorf("%v(): argument-%v %v evaluates to %v, which is incompatible", funcName, i+1, funcs[i], funcs[i].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case NullIf:
if len(funcs) != 2 {
err := fmt.Errorf("%v(): exactly two arguments expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
if !funcs[0].ReturnType().isBaseKind() {
err := fmt.Errorf("%v(): argument-1 %v evaluates to %v, which is incompatible", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
if !funcs[1].ReturnType().isBaseKind() {
err := fmt.Errorf("%v(): argument-2 %v evaluates to %v, which is incompatible", funcName, funcs[1], funcs[1].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case UTCNow:
if len(funcs) != 0 {
err := fmt.Errorf("%v(): no argument expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
case Substring:
if len(funcs) < 2 || len(funcs) > 3 {
err := fmt.Errorf("%v(): exactly two or three arguments expected; got %v", funcName, len(funcs))
return nil, errEvaluatorInvalidArguments(err)
}
if !funcs[0].ReturnType().isStringKind() {
err := fmt.Errorf("%v(): argument-1 %v evaluates to %v, not a string", funcName, funcs[0], funcs[0].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
if !funcs[1].ReturnType().isIntKind() {
err := fmt.Errorf("%v(): argument-2 %v evaluates to %v, not an int", funcName, funcs[1], funcs[1].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
if len(funcs) > 2 {
if !funcs[2].ReturnType().isIntKind() {
err := fmt.Errorf("%v(): argument-3 %v evaluates to %v, not an int", funcName, funcs[2], funcs[2].ReturnType())
return nil, errIncorrectSQLFunctionArgumentType(err)
}
}
return &funcExpr{
args: funcs,
name: funcName,
}, nil
}
return nil, errUnsupportedFunction(fmt.Errorf("unknown function name %v", funcName))
}
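Aggregates follow a two-phase protocol: Eval accumulates per record and returns nil, then AggregateValue produces the final number. A hypothetical sketch with a constant argument:

// exampleAggregate is an illustrative sketch of SUM(2) over three records.
func exampleAggregate() (*Value, error) {
	sum, err := newFuncExpr(Sum, newValueExpr(NewInt(2)))
	if err != nil {
		return nil, err
	}
	for i := 0; i < 3; i++ {
		if _, err = sum.Eval(nil); err != nil { // accumulates 2 each pass
			return nil, err
		}
	}
	return sum.AggregateValue() // -> FLOAT 6
}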

@ -0,0 +1,336 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import "fmt"
// andExpr - logical AND function.
type andExpr struct {
left Expr
right Expr
funcType Type
}
// String - returns string representation of this function.
func (f *andExpr) String() string {
return fmt.Sprintf("(%v AND %v)", f.left, f.right)
}
// Eval - evaluates this function for the given record and returns the result as a Value.
func (f *andExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
_, err = f.right.Eval(record)
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// AggregateValue - returns aggregated value.
func (f *andExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not an aggregate expression", f)
return nil, errExternalEvalException(err)
}
leftValue, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if !leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// Type - returns logicalFunction or aggregateFunction type.
func (f *andExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *andExpr) ReturnType() Type {
return Bool
}
// newAndExpr - creates new AND logical function.
func newAndExpr(left, right Expr) (*andExpr, error) {
if !left.ReturnType().isBoolKind() {
err := fmt.Errorf("operator AND: left side expression %v evaluates to %v, not bool", left, left.ReturnType())
return nil, errInvalidDataType(err)
}
if !right.ReturnType().isBoolKind() {
err := fmt.Errorf("operator AND: right side expression %v evaluates to %v, not bool", right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := logicalFunction
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
if left.Type() == column {
err := fmt.Errorf("operator AND: left side expression %v return type %v is incompatible for aggregate evaluation", left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
if right.Type() == column {
err := fmt.Errorf("operator AND: right side expression %v return type %v is incompatible for aggregate evaluation", right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &andExpr{
left: left,
right: right,
funcType: funcType,
}, nil
}
// orExpr - logical OR function.
type orExpr struct {
left Expr
right Expr
funcType Type
}
// String - returns string representation of this function.
func (f *orExpr) String() string {
return fmt.Sprintf("(%v OR %v)", f.left, f.right)
}
// Eval - evaluates this function for the given record and returns the result as a Value.
func (f *orExpr) Eval(record Record) (*Value, error) {
leftValue, err := f.left.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
_, err = f.right.Eval(record)
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// AggregateValue - returns aggregated value.
func (f *orExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not an aggregate expression", f)
return nil, errExternalEvalException(err)
}
leftValue, err := f.left.AggregateValue()
if err != nil {
return nil, err
}
if leftValue.Type() != Bool {
err := fmt.Errorf("%v: left side expression evaluated to %v; not to bool", f, leftValue.Type())
return nil, errExternalEvalException(err)
}
if leftValue.BoolValue() {
return leftValue, nil
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return rightValue, nil
}
// Type - returns logicalFunction or aggregateFunction type.
func (f *orExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *orExpr) ReturnType() Type {
return Bool
}
// newOrExpr - creates new OR logical function.
func newOrExpr(left, right Expr) (*orExpr, error) {
if !left.ReturnType().isBoolKind() {
err := fmt.Errorf("operator OR: left side expression %v evaluates to %v, not bool", left, left.ReturnType())
return nil, errInvalidDataType(err)
}
if !right.ReturnType().isBoolKind() {
err := fmt.Errorf("operator OR: right side expression %v evaluates to %v, not bool", right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := logicalFunction
if left.Type() == aggregateFunction || right.Type() == aggregateFunction {
funcType = aggregateFunction
if left.Type() == column {
err := fmt.Errorf("operator OR: left side expression %v return type %v is incompatible for aggregate evaluation", left, left.Type())
return nil, errUnsupportedSQLOperation(err)
}
if right.Type() == column {
err := fmt.Errorf("operator OR: right side expression %v return type %v is incompatible for aggregate evaluation", right, right.Type())
return nil, errUnsupportedSQLOperation(err)
}
}
return &orExpr{
left: left,
right: right,
funcType: funcType,
}, nil
}
// notExpr - logical NOT function.
type notExpr struct {
right Expr
funcType Type
}
// String - returns string representation of this function.
func (f *notExpr) String() string {
return fmt.Sprintf("(NOT %v)", f.right)
}
// Eval - evaluates this function for the given record and returns the result as a Value.
func (f *notExpr) Eval(record Record) (*Value, error) {
rightValue, err := f.right.Eval(record)
if err != nil {
return nil, err
}
if f.funcType == aggregateFunction {
return nil, nil
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(!rightValue.BoolValue()), nil
}
// AggregateValue - returns aggregated value.
func (f *notExpr) AggregateValue() (*Value, error) {
if f.funcType != aggregateFunction {
err := fmt.Errorf("%v is not an aggregate expression", f)
return nil, errExternalEvalException(err)
}
rightValue, err := f.right.AggregateValue()
if err != nil {
return nil, err
}
if rightValue.Type() != Bool {
err := fmt.Errorf("%v: right side expression evaluated to %v; not to bool", f, rightValue.Type())
return nil, errExternalEvalException(err)
}
return NewBool(!rightValue.BoolValue()), nil
}
// Type - returns logicalFunction or aggregateFunction type.
func (f *notExpr) Type() Type {
return f.funcType
}
// ReturnType - returns Bool as return type.
func (f *notExpr) ReturnType() Type {
return Bool
}
// newNotExpr - creates new NOT logical function.
func newNotExpr(right Expr) (*notExpr, error) {
if !right.ReturnType().isBoolKind() {
err := fmt.Errorf("operator NOT: right side expression %v evaluates to %v, not bool", right, right.ReturnType())
return nil, errInvalidDataType(err)
}
funcType := logicalFunction
if right.Type() == aggregateFunction {
funcType = aggregateFunction
}
return &notExpr{
right: right,
funcType: funcType,
}, nil
}
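AND and OR short-circuit on the left operand. A hypothetical constant sketch:

// exampleShortCircuit is an illustrative sketch: (false AND x) returns the
// left value without evaluating the right side.
func exampleShortCircuit() (*Value, error) {
	f, err := newAndExpr(newValueExpr(NewBool(false)), newValueExpr(NewBool(true)))
	if err != nil {
		return nil, err
	}
	return f.Eval(nil) // -> BOOL false; the right side is never evaluated
}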

@ -0,0 +1,25 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
// Record - is a type containing columns and their values.
type Record interface {
Get(name string) (*Value, error)
Set(name string, value *Value) error
MarshalCSV(fieldDelimiter rune) ([]byte, error)
MarshalJSON() ([]byte, error)
}
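For tests or experimentation, a minimal in-memory Record might look like the sketch below (a hypothetical type, assuming encoding/json, fmt, and strings are imported; the real implementations ship with the CSV and JSON readers elsewhere in this commit):

// mapRecord is a hypothetical map-backed Record for illustration only.
type mapRecord struct {
	names  []string // preserves column insertion order for CSV output
	values map[string]*Value
}

func (r *mapRecord) Get(name string) (*Value, error) {
	if v, ok := r.values[name]; ok {
		return v, nil
	}
	return nil, fmt.Errorf("column %v not found", name)
}

func (r *mapRecord) Set(name string, value *Value) error {
	if _, ok := r.values[name]; !ok {
		r.names = append(r.names, name)
	}
	r.values[name] = value
	return nil
}

func (r *mapRecord) MarshalCSV(fieldDelimiter rune) ([]byte, error) {
	fields := make([]string, 0, len(r.names))
	for _, name := range r.names {
		fields = append(fields, r.values[name].CSVString())
	}
	return []byte(strings.Join(fields, string(fieldDelimiter))), nil
}

func (r *mapRecord) MarshalJSON() ([]byte, error) {
	return json.Marshal(r.values) // *Value implements MarshalJSON
}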

pkg/s3select/sql/sql.go Normal file
@ -0,0 +1,529 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"fmt"
"strings"
"github.com/xwb1989/sqlparser"
)
func getColumnName(colName *sqlparser.ColName) string {
columnName := colName.Qualifier.Name.String()
if qualifier := colName.Qualifier.Qualifier.String(); qualifier != "" {
columnName = qualifier + "." + columnName
}
if columnName == "" {
columnName = colName.Name.String()
} else {
columnName = columnName + "." + colName.Name.String()
}
return columnName
}
func newLiteralExpr(parserExpr sqlparser.Expr, tableAlias string) (Expr, error) {
switch parserExpr.(type) {
case *sqlparser.NullVal:
return newValueExpr(NewNull()), nil
case sqlparser.BoolVal:
return newValueExpr(NewBool((bool(parserExpr.(sqlparser.BoolVal))))), nil
case *sqlparser.SQLVal:
sqlValue := parserExpr.(*sqlparser.SQLVal)
value, err := NewValue(sqlValue)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
case *sqlparser.ColName:
columnName := getColumnName(parserExpr.(*sqlparser.ColName))
if tableAlias != "" {
if !strings.HasPrefix(columnName, tableAlias+".") {
err := fmt.Errorf("column name %v does not start with table alias %v", columnName, tableAlias)
return nil, errInvalidKeyPath(err)
}
columnName = strings.TrimPrefix(columnName, tableAlias+".")
}
return newColumnExpr(columnName), nil
case sqlparser.ValTuple:
var valueType Type
var values []*Value
for i, valExpr := range parserExpr.(sqlparser.ValTuple) {
sqlVal, ok := valExpr.(*sqlparser.SQLVal)
if !ok {
return nil, errParseInvalidTypeParam(fmt.Errorf("value %v in tuple must be a primitive value", i+1))
}
val, err := NewValue(sqlVal)
if err != nil {
return nil, err
}
if i == 0 {
valueType = val.Type()
} else if valueType != val.Type() {
return nil, errParseInvalidTypeParam(fmt.Errorf("mixed value types are not allowed in a tuple"))
}
values = append(values, val)
}
return newValueExpr(NewArray(values)), nil
}
return nil, nil
}
func isExprToComparisonExpr(parserExpr *sqlparser.IsExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func rangeCondToComparisonFunc(parserExpr *sqlparser.RangeCond, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
fromExpr, err := newExpr(parserExpr.From, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
toExpr, err := newExpr(parserExpr.To, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, fromExpr, toExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() || !fromExpr.Type().isBase() || !toExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toComparisonExpr(parserExpr *sqlparser.ComparisonExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, rightExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toArithExpr(parserExpr *sqlparser.BinaryExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newArithExpr(ArithOperator(parserExpr.Operator), leftExpr, rightExpr)
if err != nil {
return nil, err
}
if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toFuncExpr(parserExpr *sqlparser.FuncExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
funcName := strings.ToUpper(parserExpr.Name.String())
if !isSelectExpr && isAggregateFuncName(funcName) {
return nil, errUnsupportedSQLOperation(fmt.Errorf("%v() must be used in a select expression", funcName))
}
funcs, aggregatedExprFound, err := newSelectExprs(parserExpr.Exprs, tableAlias)
if err != nil {
return nil, err
}
if aggregatedExprFound {
return nil, errIncorrectSQLFunctionArgumentType(fmt.Errorf("%v(): an aggregated expression must not be used as an argument", funcName))
}
return newFuncExpr(FuncName(funcName), funcs...)
}
func toAndExpr(parserExpr *sqlparser.AndExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newAndExpr(leftExpr, rightExpr)
if err != nil {
return nil, err
}
if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toOrExpr(parserExpr *sqlparser.OrExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newOrExpr(leftExpr, rightExpr)
if err != nil {
return nil, err
}
if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func toNotExpr(parserExpr *sqlparser.NotExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
rightExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
if err != nil {
return nil, err
}
f, err := newNotExpr(rightExpr)
if err != nil {
return nil, err
}
if rightExpr.Type() != Bool {
return f, nil
}
value, err := f.Eval(nil)
if err != nil {
return nil, err
}
return newValueExpr(value), nil
}
func newExpr(parserExpr sqlparser.Expr, tableAlias string, isSelectExpr bool) (Expr, error) {
f, err := newLiteralExpr(parserExpr, tableAlias)
if err != nil {
return nil, err
}
if f != nil {
return f, nil
}
switch parserExpr.(type) {
case *sqlparser.ParenExpr:
return newExpr(parserExpr.(*sqlparser.ParenExpr).Expr, tableAlias, isSelectExpr)
case *sqlparser.IsExpr:
return isExprToComparisonExpr(parserExpr.(*sqlparser.IsExpr), tableAlias, isSelectExpr)
case *sqlparser.RangeCond:
return rangeCondToComparisonFunc(parserExpr.(*sqlparser.RangeCond), tableAlias, isSelectExpr)
case *sqlparser.ComparisonExpr:
return toComparisonExpr(parserExpr.(*sqlparser.ComparisonExpr), tableAlias, isSelectExpr)
case *sqlparser.BinaryExpr:
return toArithExpr(parserExpr.(*sqlparser.BinaryExpr), tableAlias, isSelectExpr)
case *sqlparser.FuncExpr:
return toFuncExpr(parserExpr.(*sqlparser.FuncExpr), tableAlias, isSelectExpr)
case *sqlparser.AndExpr:
return toAndExpr(parserExpr.(*sqlparser.AndExpr), tableAlias, isSelectExpr)
case *sqlparser.OrExpr:
return toOrExpr(parserExpr.(*sqlparser.OrExpr), tableAlias, isSelectExpr)
case *sqlparser.NotExpr:
return toNotExpr(parserExpr.(*sqlparser.NotExpr), tableAlias, isSelectExpr)
}
return nil, errParseUnsupportedSyntax(fmt.Errorf("unknown expression type %T; %v", parserExpr, parserExpr))
}
func newSelectExprs(parserSelectExprs []sqlparser.SelectExpr, tableAlias string) ([]Expr, bool, error) {
var funcs []Expr
starExprFound := false
aggregatedExprFound := false
for _, selectExpr := range parserSelectExprs {
switch selectExpr.(type) {
case *sqlparser.AliasedExpr:
if starExprFound {
return nil, false, errParseAsteriskIsNotAloneInSelectList(nil)
}
aliasedExpr := selectExpr.(*sqlparser.AliasedExpr)
f, err := newExpr(aliasedExpr.Expr, tableAlias, true)
if err != nil {
return nil, false, err
}
if f.Type() == aggregateFunction {
if !aggregatedExprFound {
aggregatedExprFound = true
if len(funcs) > 0 {
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expressions must not be mixed with aggregated expressions"))
}
}
} else if aggregatedExprFound {
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expressions must not be mixed with aggregated expressions"))
}
alias := aliasedExpr.As.String()
if alias != "" {
f = newAliasExpr(alias, f)
}
funcs = append(funcs, f)
case *sqlparser.StarExpr:
if starExprFound {
err := fmt.Errorf("only a single star expression is allowed")
return nil, false, errParseInvalidContextForWildcardInSelectList(err)
}
starExprFound = true
funcs = append(funcs, newStarExpr())
default:
return nil, false, errParseUnsupportedSyntax(fmt.Errorf("unknown select expression %v", selectExpr))
}
}
return funcs, aggregatedExprFound, nil
}
// Select - SQL Select statement.
type Select struct {
tableName string
tableAlias string
selectExprs []Expr
aggregatedExprFound bool
whereExpr Expr
}
// TableAlias - returns table alias name.
func (statement *Select) TableAlias() string {
return statement.tableAlias
}
// IsSelectAll - returns whether '*' is used in select expression or not.
func (statement *Select) IsSelectAll() bool {
if len(statement.selectExprs) == 1 {
_, ok := statement.selectExprs[0].(*starExpr)
return ok
}
return false
}
// IsAggregated - returns whether aggregated functions are used in select expression or not.
func (statement *Select) IsAggregated() bool {
return statement.aggregatedExprFound
}
// AggregateResult - returns aggregate result as record.
func (statement *Select) AggregateResult(output Record) error {
if !statement.aggregatedExprFound {
return nil
}
for i, expr := range statement.selectExprs {
value, err := expr.AggregateValue()
if err != nil {
return err
}
if value == nil {
return errInternalError(fmt.Errorf("%v returns <nil> for AggregateValue()", expr))
}
name := fmt.Sprintf("_%v", i+1)
if _, ok := expr.(*aliasExpr); ok {
name = expr.(*aliasExpr).alias
}
if err = output.Set(name, value); err != nil {
return errInternalError(fmt.Errorf("failed to store value %v for %v: %v", value, name, err))
}
}
return nil
}
// Eval - evaluates this Select's expressions for the given record.
func (statement *Select) Eval(input, output Record) (Record, error) {
if statement.whereExpr != nil {
value, err := statement.whereExpr.Eval(input)
if err != nil {
return nil, err
}
if value == nil || value.valueType != Bool {
err = fmt.Errorf("WHERE expression %v returns invalid bool value %v", statement.whereExpr, value)
return nil, errInternalError(err)
}
if !value.BoolValue() {
return nil, nil
}
}
// Evaluate the select expressions for this record.
for i, expr := range statement.selectExprs {
value, err := expr.Eval(input)
if err != nil {
return nil, err
}
if statement.aggregatedExprFound {
continue
}
name := fmt.Sprintf("_%v", i+1)
switch expr.(type) {
case *starExpr:
return value.recordValue(), nil
case *aliasExpr:
name = expr.(*aliasExpr).alias
case *columnExpr:
name = expr.(*columnExpr).name
}
if err = output.Set(name, value); err != nil {
return nil, errInternalError(fmt.Errorf("failed to store value %v for %v: %v", value, name, err))
}
}
return output, nil
}
// NewSelect - creates new Select by parsing sql.
func NewSelect(sql string) (*Select, error) {
stmt, err := sqlparser.Parse(sql)
if err != nil {
return nil, errUnsupportedSQLStructure(err)
}
selectStmt, ok := stmt.(*sqlparser.Select)
if !ok {
return nil, errParseUnsupportedSelect(fmt.Errorf("unsupported SQL statement %v", sql))
}
var tableName, tableAlias string
for _, fromExpr := range selectStmt.From {
tableExpr := fromExpr.(*sqlparser.AliasedTableExpr)
tableName = tableExpr.Expr.(sqlparser.TableName).Name.String()
tableAlias = tableExpr.As.String()
}
selectExprs, aggregatedExprFound, err := newSelectExprs(selectStmt.SelectExprs, tableAlias)
if err != nil {
return nil, err
}
var whereExpr Expr
if selectStmt.Where != nil {
whereExpr, err = newExpr(selectStmt.Where.Expr, tableAlias, false)
if err != nil {
return nil, err
}
}
return &Select{
tableName: tableName,
tableAlias: tableAlias,
selectExprs: selectExprs,
aggregatedExprFound: aggregatedExprFound,
whereExpr: whereExpr,
}, nil
}
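End to end, the package is driven through NewSelect plus per-record Eval. A hypothetical sketch, where in and out stand for any Record implementations such as the mapRecord sketched earlier:

// exampleSelect is an illustrative end-to-end sketch.
func exampleSelect(in, out Record) (Record, error) {
	stmt, err := NewSelect("SELECT s.name FROM S3Object s WHERE s.age > 25")
	if err != nil {
		return nil, err // an *s3Error carrying an AWS-style error code
	}
	// A nil record with a nil error means the WHERE clause rejected the row.
	return stmt.Eval(in, out)
}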

pkg/s3select/sql/type.go Normal file
@ -0,0 +1,118 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
// Type - value type.
type Type string
const (
// Null - represents NULL value type.
Null Type = "null"
// Bool - represents boolean value type.
Bool Type = "bool"
// Int - represents integer value type.
Int Type = "int"
// Float - represents floating point value type.
Float Type = "float"
// String - represents string value type.
String Type = "string"
// Timestamp - represents time value type.
Timestamp Type = "timestamp"
// Array - represents an array of values where each value type is one of the above.
Array Type = "array"
column Type = "column"
record Type = "record"
function Type = "function"
aggregateFunction Type = "aggregatefunction"
arithmeticFunction Type = "arithmeticfunction"
comparisonFunction Type = "comparisonfunction"
logicalFunction Type = "logicalfunction"
// Integer Type = "integer" // Same as Int
// Decimal Type = "decimal" // Same as Float
// Numeric Type = "numeric" // Same as Float
)
func (t Type) isBase() bool {
switch t {
case Null, Bool, Int, Float, String, Timestamp:
return true
}
return false
}
func (t Type) isBaseKind() bool {
switch t {
case Null, Bool, Int, Float, String, Timestamp, column:
return true
}
return false
}
func (t Type) isNumber() bool {
switch t {
case Int, Float:
return true
}
return false
}
func (t Type) isNumberKind() bool {
switch t {
case Int, Float, column:
return true
}
return false
}
func (t Type) isIntKind() bool {
switch t {
case Int, column:
return true
}
return false
}
func (t Type) isBoolKind() bool {
switch t {
case Bool, column:
return true
}
return false
}
func (t Type) isStringKind() bool {
switch t {
case String, column:
return true
}
return false
}
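// Note on the *Kind helpers above: a column reference carries the internal
// "column" type until a record is actually evaluated, so parse-time checks
// accept "column" wherever a concrete base type would do. For example,
// isNumberKind lets "s.age + 1" validate even though s.age resolves to Int or
// Float only once each record is read.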

223
pkg/s3select/sql/value.go Normal file
View File

@ -0,0 +1,223 @@
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/xwb1989/sqlparser"
)
// Value - represents a typed value: bool, int, float, string, timestamp or an array of these.
type Value struct {
value interface{}
valueType Type
}
// String - represents value as string.
func (value *Value) String() string {
if value.value == nil {
if value.valueType == Null {
return "NULL"
}
return "<nil>"
}
switch value.valueType {
case String:
return fmt.Sprintf("'%v'", value.value)
case Array:
var valueStrings []string
for _, v := range value.value.([]*Value) {
valueStrings = append(valueStrings, fmt.Sprintf("%v", v))
}
return fmt.Sprintf("(%v)", strings.Join(valueStrings, ","))
}
return fmt.Sprintf("%v", value.value)
}
// CSVString - encodes to CSV string.
func (value *Value) CSVString() string {
return fmt.Sprintf("%v", value.value)
}
// MarshalJSON - encodes to JSON data.
func (value *Value) MarshalJSON() ([]byte, error) {
return json.Marshal(value.value)
}
// BoolValue - returns underlying bool value. It panics if value is not Bool type.
func (value *Value) BoolValue() bool {
if value.valueType == Bool {
return value.value.(bool)
}
panic(fmt.Sprintf("requested bool value but found %T type", value.value))
}
// IntValue - returns underlying int value. It panics if value is not Int type.
func (value *Value) IntValue() int64 {
if value.valueType == Int {
return value.value.(int64)
}
panic(fmt.Sprintf("requested int value but found %T type", value.value))
}
// FloatValue - returns underlying int/float value as float64. It panics if value is not Int/Float type.
func (value *Value) FloatValue() float64 {
switch value.valueType {
case Int:
return float64(value.value.(int64))
case Float:
return value.value.(float64)
}
panic(fmt.Sprintf("requested float value but found %T type", value.value))
}
// StringValue - returns underlying string value. It panics if value is not String type.
func (value *Value) StringValue() string {
if value.valueType == String {
return value.value.(string)
}
panic(fmt.Sprintf("requested string value but found %T type", value.value))
}
// TimeValue - returns underlying time value. It panics if value is not Timestamp type.
func (value *Value) TimeValue() time.Time {
if value.valueType == Timestamp {
return value.value.(time.Time)
}
panic(fmt.Sprintf("requested time value but found %T type", value.value))
}
// ArrayValue - returns underlying value array. It panics if value is not Array type.
func (value *Value) ArrayValue() []*Value {
if value.valueType == Array {
return value.value.([]*Value)
}
panic(fmt.Sprintf("requested array value but found %T type", value.value))
}
func (value *Value) recordValue() Record {
if value.valueType == record {
return value.value.(Record)
}
panic(fmt.Sprintf("requested record value but found %T type", value.value))
}
// Type - returns value type.
func (value *Value) Type() Type {
return value.valueType
}
// Value - returns the underlying value as an interface.
func (value *Value) Value() interface{} {
return value.value
}
// NewNull - creates new null value.
func NewNull() *Value {
return &Value{nil, Null}
}
// NewBool - creates new Bool value of b.
func NewBool(b bool) *Value {
return &Value{b, Bool}
}
// NewInt - creates new Int value of i.
func NewInt(i int64) *Value {
return &Value{i, Int}
}
// NewFloat - creates new Float value of f.
func NewFloat(f float64) *Value {
return &Value{f, Float}
}
// NewString - creates new String value of s.
func NewString(s string) *Value {
return &Value{s, String}
}
// NewTime - creates new Time value of t.
func NewTime(t time.Time) *Value {
return &Value{t, Timestamp}
}
// NewArray - creates new Array value of values.
func NewArray(values []*Value) *Value {
return &Value{values, Array}
}
func newRecordValue(r Record) *Value {
return &Value{r, record}
}
// NewValue - creates new Value from SQLVal v.
func NewValue(v *sqlparser.SQLVal) (*Value, error) {
switch v.Type {
case sqlparser.StrVal:
return NewString(string(v.Val)), nil
case sqlparser.IntVal:
i64, err := strconv.ParseInt(string(v.Val), 10, 64)
if err != nil {
return nil, err
}
return NewInt(i64), nil
case sqlparser.FloatVal:
f64, err := strconv.ParseFloat(string(v.Val), 64)
if err != nil {
return nil, err
}
return NewFloat(f64), nil
case sqlparser.HexNum: // represented as 0xDD
i64, err := strconv.ParseInt(string(v.Val), 16, 64)
if err != nil {
return nil, err
}
return NewInt(i64), nil
case sqlparser.HexVal: // represented as X'0DD'
i64, err := strconv.ParseInt(string(v.Val), 16, 64)
if err != nil {
return nil, err
}
return NewInt(i64), nil
case sqlparser.BitVal: // represented as B'00'
i64, err := strconv.ParseInt(string(v.Val), 2, 64)
if err != nil {
return nil, err
}
return NewInt(i64), nil
case sqlparser.ValArg:
// FIXME: the format of ValArg is unknown and it is unclear how to handle it.
}
return nil, fmt.Errorf("unknown SQL value %v; %v", v, v.Type)
}
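// Worked examples for NewValue (assuming sqlparser supplies the bare digits
// for hex/bit literals, as the base-16/base-2 parses above require): the SQL
// literal 42 yields NewInt(42), 3.14 yields NewFloat(3.14), X'2A' parses
// base-16 to NewInt(42), and B'101' parses base-2 to NewInt(5).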

Binary file not shown.

View File

@ -0,0 +1,642 @@
// +build ignore
/*
* Minio Cloud Storage, (C) 2019 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package s3select
///////////////////////////////////////////////////////////////////////
//
// Validation errors.
//
///////////////////////////////////////////////////////////////////////
func errExpressionTooLong(err error) *s3Error {
return &s3Error{
code: "ExpressionTooLong",
message: "The SQL expression is too long: The maximum byte-length for the SQL expression is 256 KB.",
statusCode: 400,
cause: err,
}
}
func errColumnTooLong(err error) *s3Error {
return &s3Error{
code: "ColumnTooLong",
message: "The length of a column in the result is greater than maxCharsPerColumn of 1 MB.",
statusCode: 400,
cause: err,
}
}
func errOverMaxColumn(err error) *s3Error {
return &s3Error{
code: "OverMaxColumn",
message: "The number of columns in the result is greater than the maximum allowable number of columns.",
statusCode: 400,
cause: err,
}
}
func errOverMaxRecordSize(err error) *s3Error {
return &s3Error{
code: "OverMaxRecordSize",
message: "The length of a record in the input or result is greater than maxCharsPerRecord of 1 MB.",
statusCode: 400,
cause: err,
}
}
func errInvalidColumnIndex(err error) *s3Error {
return &s3Error{
code: "InvalidColumnIndex",
message: "Column index in the SQL expression is invalid.",
statusCode: 400,
cause: err,
}
}
func errInvalidTextEncoding(err error) *s3Error {
return &s3Error{
code: "InvalidTextEncoding",
message: "Invalid encoding type. Only UTF-8 encoding is supported.",
statusCode: 400,
cause: err,
}
}
func errInvalidTableAlias(err error) *s3Error {
return &s3Error{
code: "InvalidTableAlias",
message: "The SQL expression contains an invalid table alias.",
statusCode: 400,
cause: err,
}
}
func errUnsupportedSyntax(err error) *s3Error {
return &s3Error{
code: "UnsupportedSyntax",
message: "Encountered invalid syntax.",
statusCode: 400,
cause: err,
}
}
func errAmbiguousFieldName(err error) *s3Error {
return &s3Error{
code: "AmbiguousFieldName",
message: "Field name matches to multiple fields in the file. Check the SQL expression and the file, and try again.",
statusCode: 400,
cause: err,
}
}
func errIntegerOverflow(err error) *s3Error {
return &s3Error{
code: "IntegerOverflow",
message: "Integer overflow or underflow in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errIllegalSQLFunctionArgument(err error) *s3Error {
return &s3Error{
code: "IllegalSqlFunctionArgument",
message: "Illegal argument was used in the SQL function.",
statusCode: 400,
cause: err,
}
}
func errMultipleDataSourcesUnsupported(err error) *s3Error {
return &s3Error{
code: "MultipleDataSourcesUnsupported",
message: "Multiple data sources are not supported.",
statusCode: 400,
cause: err,
}
}
func errMissingHeaders(err error) *s3Error {
return &s3Error{
code: "MissingHeaders",
message: "Some headers in the query are missing from the file. Check the file and try again.",
statusCode: 400,
cause: err,
}
}
func errUnrecognizedFormatException(err error) *s3Error {
return &s3Error{
code: "UnrecognizedFormatException",
message: "Encountered an invalid record type.",
statusCode: 400,
cause: err,
}
}
//////////////////////////////////////////////////////////////////////////////////////
//
// SQL parsing errors.
//
//////////////////////////////////////////////////////////////////////////////////////
func errLexerInvalidChar(err error) *s3Error {
return &s3Error{
code: "LexerInvalidChar",
message: "The SQL expression contains an invalid character.",
statusCode: 400,
cause: err,
}
}
func errLexerInvalidOperator(err error) *s3Error {
return &s3Error{
code: "LexerInvalidOperator",
message: "The SQL expression contains an invalid literal.",
statusCode: 400,
cause: err,
}
}
func errLexerInvalidLiteral(err error) *s3Error {
return &s3Error{
code: "LexerInvalidLiteral",
message: "The SQL expression contains an invalid operator.",
statusCode: 400,
cause: err,
}
}
func errLexerInvalidIONLiteral(err error) *s3Error {
return &s3Error{
code: "LexerInvalidIONLiteral",
message: "The SQL expression contains an invalid operator.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedDatePart(err error) *s3Error {
return &s3Error{
code: "ParseExpectedDatePart",
message: "Did not find the expected date part in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedKeyword(err error) *s3Error {
return &s3Error{
code: "ParseExpectedKeyword",
message: "Did not find the expected keyword in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedTokenType(err error) *s3Error {
return &s3Error{
code: "ParseExpectedTokenType",
message: "Did not find the expected token in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpected2TokenTypes(err error) *s3Error {
return &s3Error{
code: "ParseExpected2TokenTypes",
message: "Did not find the expected token in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedNumber(err error) *s3Error {
return &s3Error{
code: "ParseExpectedNumber",
message: "Did not find the expected number in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedRightParenBuiltinFunctionCall(err error) *s3Error {
return &s3Error{
code: "ParseExpectedRightParenBuiltinFunctionCall",
message: "Did not find the expected right parenthesis character in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedTypeName(err error) *s3Error {
return &s3Error{
code: "ParseExpectedTypeName",
message: "Did not find the expected type name in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedWhenClause(err error) *s3Error {
return &s3Error{
code: "ParseExpectedWhenClause",
message: "Did not find the expected WHEN clause in the SQL expression. CASE is not supported.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedToken(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedToken",
message: "The SQL expression contains an unsupported token.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedLiteralsGroupBy(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedLiteralsGroupBy",
message: "The SQL expression contains an unsupported use of GROUP BY.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedMember(err error) *s3Error {
return &s3Error{
code: "ParseExpectedMember",
message: "The SQL expression contains an unsupported use of MEMBER.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedCase(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedCase",
message: "The SQL expression contains an unsupported use of CASE.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedCaseClause(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedCaseClause",
message: "The SQL expression contains an unsupported use of CASE.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedAlias(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedAlias",
message: "The SQL expression contains an unsupported use of ALIAS.",
statusCode: 400,
cause: err,
}
}
func errParseInvalidPathComponent(err error) *s3Error {
return &s3Error{
code: "ParseInvalidPathComponent",
message: "The SQL expression contains an invalid path component.",
statusCode: 400,
cause: err,
}
}
func errParseMissingIdentAfterAt(err error) *s3Error {
return &s3Error{
code: "ParseMissingIdentAfterAt",
message: "Did not find the expected identifier after the @ symbol in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseUnexpectedOperator(err error) *s3Error {
return &s3Error{
code: "ParseUnexpectedOperator",
message: "The SQL expression contains an unexpected operator.",
statusCode: 400,
cause: err,
}
}
func errParseUnexpectedTerm(err error) *s3Error {
return &s3Error{
code: "ParseUnexpectedTerm",
message: "The SQL expression contains an unexpected term.",
statusCode: 400,
cause: err,
}
}
func errParseUnexpectedToken(err error) *s3Error {
return &s3Error{
code: "ParseUnexpectedToken",
message: "The SQL expression contains an unexpected token.",
statusCode: 400,
cause: err,
}
}
func errParseUnExpectedKeyword(err error) *s3Error {
return &s3Error{
code: "ParseUnExpectedKeyword",
message: "The SQL expression contains an unexpected keyword.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedExpression(err error) *s3Error {
return &s3Error{
code: "ParseExpectedExpression",
message: "Did not find the expected SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedLeftParenAfterCast(err error) *s3Error {
return &s3Error{
code: "ParseExpectedLeftParenAfterCast",
message: "Did not find the expected left parenthesis after CAST in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedLeftParenValueConstructor(err error) *s3Error {
return &s3Error{
code: "ParseExpectedLeftParenValueConstructor",
message: "Did not find expected the left parenthesis in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedLeftParenBuiltinFunctionCall(err error) *s3Error {
return &s3Error{
code: "ParseExpectedLeftParenBuiltinFunctionCall",
message: "Did not find the expected left parenthesis in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedArgumentDelimiter(err error) *s3Error {
return &s3Error{
code: "ParseExpectedArgumentDelimiter",
message: "Did not find the expected argument delimiter in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseCastArity(err error) *s3Error {
return &s3Error{
code: "ParseCastArity",
message: "The SQL expression CAST has incorrect arity.",
statusCode: 400,
cause: err,
}
}
func errParseEmptySelect(err error) *s3Error {
return &s3Error{
code: "ParseEmptySelect",
message: "The SQL expression contains an empty SELECT.",
statusCode: 400,
cause: err,
}
}
func errParseSelectMissingFrom(err error) *s3Error {
return &s3Error{
code: "ParseSelectMissingFrom",
message: "The SQL expression contains a missing FROM after SELECT list.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedIdentForGroupName(err error) *s3Error {
return &s3Error{
code: "ParseExpectedIdentForGroupName",
message: "GROUP is not supported in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedIdentForAlias(err error) *s3Error {
return &s3Error{
code: "ParseExpectedIdentForAlias",
message: "Did not find the expected identifier for the alias in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseUnsupportedCallWithStar(err error) *s3Error {
return &s3Error{
code: "ParseUnsupportedCallWithStar",
message: "Only COUNT with (*) as a parameter is supported in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseMalformedJoin(err error) *s3Error {
return &s3Error{
code: "ParseMalformedJoin",
message: "JOIN is not supported in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseExpectedIdentForAt(err error) *s3Error {
return &s3Error{
code: "ParseExpectedIdentForAt",
message: "Did not find the expected identifier for AT name in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errParseCannotMixSqbAndWildcardInSelectList(err error) *s3Error {
return &s3Error{
code: "ParseCannotMixSqbAndWildcardInSelectList",
message: "Cannot mix [] and * in the same expression in a SELECT list in SQL expression.",
statusCode: 400,
cause: err,
}
}
//////////////////////////////////////////////////////////////////////////////////////
//
// CAST() and timestamp evaluation related errors.
//
//////////////////////////////////////////////////////////////////////////////////////
func errCastFailed(err error) *s3Error {
return &s3Error{
code: "CastFailed",
message: "Attempt to convert from one data type to another using CAST failed in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errInvalidCast(err error) *s3Error {
return &s3Error{
code: "InvalidCast",
message: "Attempt to convert from one data type to another using CAST failed in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorInvalidTimestampFormatPattern(err error) *s3Error {
return &s3Error{
code: "EvaluatorInvalidTimestampFormatPattern",
message: "Invalid time stamp format string in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorInvalidTimestampFormatPatternAdditionalFieldsRequired(err error) *s3Error {
return &s3Error{
code: "EvaluatorInvalidTimestampFormatPattern",
message: "Time stamp format pattern requires additional fields in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorInvalidTimestampFormatPatternSymbolForParsing(err error) *s3Error {
return &s3Error{
code: "EvaluatorInvalidTimestampFormatPatternSymbolForParsing",
message: "Time stamp format pattern contains a valid format symbol that cannot be applied to time stamp parsing in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorTimestampFormatPatternDuplicateFields(err error) *s3Error {
return &s3Error{
code: "EvaluatorTimestampFormatPatternDuplicateFields",
message: "Time stamp format pattern contains multiple format specifiers representing the time stamp field in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorTimestampFormatPatternHourClockAmPmMismatch(err error) *s3Error {
return &s3Error{
code: "EvaluatorTimestampFormatPatternHourClockAmPmMismatch",
message: "Time stamp format pattern contains a 12-hour hour of day format symbol but doesn't also contain an AM/PM field, or it contains a 24-hour hour of day format specifier and contains an AM/PM field in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorUnterminatedTimestampFormatPatternToken(err error) *s3Error {
return &s3Error{
code: "EvaluatorUnterminatedTimestampFormatPatternToken",
message: "Time stamp format pattern contains unterminated token in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorInvalidTimestampFormatPatternToken(err error) *s3Error {
return &s3Error{
code: "EvaluatorInvalidTimestampFormatPatternToken",
message: "Time stamp format pattern contains an invalid token in the SQL expression.",
statusCode: 400,
cause: err,
}
}
func errEvaluatorInvalidTimestampFormatPatternSymbol(err error) *s3Error {
return &s3Error{
code: "EvaluatorInvalidTimestampFormatPatternSymbol",
message: "Time stamp format pattern contains an invalid symbol in the SQL expression.",
statusCode: 400,
cause: err,
}
}
////////////////////////////////////////////////////////////////////////
//
// Generic S3 HTTP handler errors.
//
////////////////////////////////////////////////////////////////////////
func errBusy(err error) *s3Error {
return &s3Error{
code: "Busy",
message: "The service is unavailable. Please retry.",
statusCode: 503,
cause: err,
}
}
func errUnauthorizedAccess(err error) *s3Error {
return &s3Error{
code: "UnauthorizedAccess",
message: "You are not authorized to perform this operation",
statusCode: 401,
cause: err,
}
}
func errEmptyRequestBody(err error) *s3Error {
return &s3Error{
code: "EmptyRequestBody",
message: "Request body cannot be empty.",
statusCode: 400,
cause: err,
}
}
func errUnsupportedRangeHeader(err error) *s3Error {
return &s3Error{
code: "UnsupportedRangeHeader",
message: "Range header is not supported for this operation.",
statusCode: 400,
cause: err,
}
}
func errUnsupportedStorageClass(err error) *s3Error {
return &s3Error{
code: "UnsupportedStorageClass",
message: "Encountered an invalid storage class. Only STANDARD, STANDARD_IA, and ONEZONE_IA storage classes are supported.",
statusCode: 400,
cause: err,
}
}
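// Sketch of a hypothetical call site (isValidStorageClass is illustrative,
// not a function in this package): wrapping the low-level error keeps the
// S3-style code, message and HTTP status together with the cause:
//
//	if sc := req.Header.Get("x-amz-storage-class"); !isValidStorageClass(sc) {
//		return errUnsupportedStorageClass(fmt.Errorf("storage class %q", sc))
//	}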

View File

@ -0,0 +1,164 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
const (
UNKNOWN_APPLICATION_EXCEPTION = 0
UNKNOWN_METHOD = 1
INVALID_MESSAGE_TYPE_EXCEPTION = 2
WRONG_METHOD_NAME = 3
BAD_SEQUENCE_ID = 4
MISSING_RESULT = 5
INTERNAL_ERROR = 6
PROTOCOL_ERROR = 7
)
var defaultApplicationExceptionMessage = map[int32]string{
UNKNOWN_APPLICATION_EXCEPTION: "unknown application exception",
UNKNOWN_METHOD: "unknown method",
INVALID_MESSAGE_TYPE_EXCEPTION: "invalid message type",
WRONG_METHOD_NAME: "wrong method name",
BAD_SEQUENCE_ID: "bad sequence ID",
MISSING_RESULT: "missing result",
INTERNAL_ERROR: "unknown internal error",
PROTOCOL_ERROR: "unknown protocol error",
}
// Application level Thrift exception
type TApplicationException interface {
TException
TypeId() int32
Read(iprot TProtocol) error
Write(oprot TProtocol) error
}
type tApplicationException struct {
message string
type_ int32
}
func (e tApplicationException) Error() string {
if e.message != "" {
return e.message
}
return defaultApplicationExceptionMessage[e.type_]
}
func NewTApplicationException(type_ int32, message string) TApplicationException {
return &tApplicationException{message, type_}
}
func (p *tApplicationException) TypeId() int32 {
return p.type_
}
func (p *tApplicationException) Read(iprot TProtocol) error {
// TODO: this should really be generated by the compiler
_, err := iprot.ReadStructBegin()
if err != nil {
return err
}
message := ""
type_ := int32(UNKNOWN_APPLICATION_EXCEPTION)
for {
_, ttype, id, err := iprot.ReadFieldBegin()
if err != nil {
return err
}
if ttype == STOP {
break
}
switch id {
case 1:
if ttype == STRING {
if message, err = iprot.ReadString(); err != nil {
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return err
}
}
case 2:
if ttype == I32 {
if type_, err = iprot.ReadI32(); err != nil {
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return err
}
}
default:
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return err
}
}
if err = iprot.ReadFieldEnd(); err != nil {
return err
}
}
if err := iprot.ReadStructEnd(); err != nil {
return err
}
p.message = message
p.type_ = type_
return nil
}
func (p *tApplicationException) Write(oprot TProtocol) (err error) {
err = oprot.WriteStructBegin("TApplicationException")
if len(p.Error()) > 0 {
err = oprot.WriteFieldBegin("message", STRING, 1)
if err != nil {
return
}
err = oprot.WriteString(p.Error())
if err != nil {
return
}
err = oprot.WriteFieldEnd()
if err != nil {
return
}
}
err = oprot.WriteFieldBegin("type", I32, 2)
if err != nil {
return
}
err = oprot.WriteI32(p.type_)
if err != nil {
return
}
err = oprot.WriteFieldEnd()
if err != nil {
return
}
err = oprot.WriteFieldStop()
if err != nil {
return
}
err = oprot.WriteStructEnd()
return
}
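// Round-trip sketch (illustrative; assumes the library's in-memory transport):
//
//	buf := NewTMemoryBuffer()
//	prot := NewTBinaryProtocolTransport(buf)
//	_ = NewTApplicationException(UNKNOWN_METHOD, "no such method").Write(prot)
//	var got tApplicationException
//	_ = got.Read(prot) // got.TypeId() == UNKNOWN_METHOD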

View File

@ -0,0 +1,509 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
)
type TBinaryProtocol struct {
trans TRichTransport
origTransport TTransport
reader io.Reader
writer io.Writer
strictRead bool
strictWrite bool
buffer [64]byte
}
type TBinaryProtocolFactory struct {
strictRead bool
strictWrite bool
}
func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol {
return NewTBinaryProtocol(t, false, true)
}
func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol {
p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite}
if et, ok := t.(TRichTransport); ok {
p.trans = et
} else {
p.trans = NewTRichTransport(t)
}
p.reader = p.trans
p.writer = p.trans
return p
}
func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory {
return NewTBinaryProtocolFactory(false, true)
}
func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory {
return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite}
}
func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
return NewTBinaryProtocol(t, p.strictRead, p.strictWrite)
}
/**
* Writing Methods
*/
func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
if p.strictWrite {
version := uint32(VERSION_1) | uint32(typeId)
e := p.WriteI32(int32(version))
if e != nil {
return e
}
e = p.WriteString(name)
if e != nil {
return e
}
e = p.WriteI32(seqId)
return e
}
e := p.WriteString(name)
if e != nil {
return e
}
e = p.WriteByte(int8(typeId))
if e != nil {
return e
}
e = p.WriteI32(seqId)
return e
}
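// Strict-mode framing written above, byte for byte: a 4-byte i32 whose high
// bits carry VERSION_1 and whose low byte carries the message type, then the
// length-prefixed method name, then the 4-byte sequence id. Non-strict mode
// omits the version word and writes the name, a single type byte, then the
// sequence id, which is why ReadMessageBegin can tell the two apart by the
// sign of the first i32 it reads.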
func (p *TBinaryProtocol) WriteMessageEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteStructBegin(name string) error {
return nil
}
func (p *TBinaryProtocol) WriteStructEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
e := p.WriteByte(int8(typeId))
if e != nil {
return e
}
e = p.WriteI16(id)
return e
}
func (p *TBinaryProtocol) WriteFieldEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteFieldStop() error {
e := p.WriteByte(STOP)
return e
}
func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
e := p.WriteByte(int8(keyType))
if e != nil {
return e
}
e = p.WriteByte(int8(valueType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
return e
}
func (p *TBinaryProtocol) WriteMapEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error {
e := p.WriteByte(int8(elemType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
return e
}
func (p *TBinaryProtocol) WriteListEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error {
e := p.WriteByte(int8(elemType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
return e
}
func (p *TBinaryProtocol) WriteSetEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteBool(value bool) error {
if value {
return p.WriteByte(1)
}
return p.WriteByte(0)
}
func (p *TBinaryProtocol) WriteByte(value int8) error {
e := p.trans.WriteByte(byte(value))
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI16(value int16) error {
v := p.buffer[0:2]
binary.BigEndian.PutUint16(v, uint16(value))
_, e := p.writer.Write(v)
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI32(value int32) error {
v := p.buffer[0:4]
binary.BigEndian.PutUint32(v, uint32(value))
_, e := p.writer.Write(v)
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI64(value int64) error {
v := p.buffer[0:8]
binary.BigEndian.PutUint64(v, uint64(value))
_, err := p.writer.Write(v)
return NewTProtocolException(err)
}
func (p *TBinaryProtocol) WriteDouble(value float64) error {
return p.WriteI64(int64(math.Float64bits(value)))
}
func (p *TBinaryProtocol) WriteString(value string) error {
e := p.WriteI32(int32(len(value)))
if e != nil {
return e
}
_, err := p.trans.WriteString(value)
return NewTProtocolException(err)
}
func (p *TBinaryProtocol) WriteBinary(value []byte) error {
e := p.WriteI32(int32(len(value)))
if e != nil {
return e
}
_, err := p.writer.Write(value)
return NewTProtocolException(err)
}
/**
* Reading methods
*/
func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
size, e := p.ReadI32()
if e != nil {
return "", typeId, 0, NewTProtocolException(e)
}
if size < 0 {
typeId = TMessageType(size & 0x0ff)
version := int64(size) & VERSION_MASK
if version != VERSION_1 {
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin"))
}
name, e = p.ReadString()
if e != nil {
return name, typeId, seqId, NewTProtocolException(e)
}
seqId, e = p.ReadI32()
if e != nil {
return name, typeId, seqId, NewTProtocolException(e)
}
return name, typeId, seqId, nil
}
if p.strictRead {
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin"))
}
name, e2 := p.readStringBody(size)
if e2 != nil {
return name, typeId, seqId, e2
}
b, e3 := p.ReadByte()
if e3 != nil {
return name, typeId, seqId, e3
}
typeId = TMessageType(b)
seqId, e4 := p.ReadI32()
if e4 != nil {
return name, typeId, seqId, e4
}
return name, typeId, seqId, nil
}
func (p *TBinaryProtocol) ReadMessageEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) {
return
}
func (p *TBinaryProtocol) ReadStructEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) {
t, err := p.ReadByte()
typeId = TType(t)
if err != nil {
return name, typeId, seqId, err
}
if t != STOP {
seqId, err = p.ReadI16()
}
return name, typeId, seqId, err
}
func (p *TBinaryProtocol) ReadFieldEnd() error {
return nil
}
var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length"))
func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) {
k, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
kType = TType(k)
v, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
vType = TType(v)
size32, e := p.ReadI32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
return kType, vType, size, nil
}
func (p *TBinaryProtocol) ReadMapEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) {
b, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
elemType = TType(b)
size32, e := p.ReadI32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
return
}
func (p *TBinaryProtocol) ReadListEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) {
b, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
elemType = TType(b)
size32, e := p.ReadI32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
return elemType, size, nil
}
func (p *TBinaryProtocol) ReadSetEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadBool() (bool, error) {
b, e := p.ReadByte()
v := true
if b != 1 {
v = false
}
return v, e
}
func (p *TBinaryProtocol) ReadByte() (int8, error) {
v, err := p.trans.ReadByte()
return int8(v), err
}
func (p *TBinaryProtocol) ReadI16() (value int16, err error) {
buf := p.buffer[0:2]
err = p.readAll(buf)
value = int16(binary.BigEndian.Uint16(buf))
return value, err
}
func (p *TBinaryProtocol) ReadI32() (value int32, err error) {
buf := p.buffer[0:4]
err = p.readAll(buf)
value = int32(binary.BigEndian.Uint32(buf))
return value, err
}
func (p *TBinaryProtocol) ReadI64() (value int64, err error) {
buf := p.buffer[0:8]
err = p.readAll(buf)
value = int64(binary.BigEndian.Uint64(buf))
return value, err
}
func (p *TBinaryProtocol) ReadDouble() (value float64, err error) {
buf := p.buffer[0:8]
err = p.readAll(buf)
value = math.Float64frombits(binary.BigEndian.Uint64(buf))
return value, err
}
func (p *TBinaryProtocol) ReadString() (value string, err error) {
size, e := p.ReadI32()
if e != nil {
return "", e
}
if size < 0 {
err = invalidDataLength
return
}
return p.readStringBody(size)
}
func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
size, e := p.ReadI32()
if e != nil {
return nil, e
}
if size < 0 {
return nil, invalidDataLength
}
isize := int(size)
buf := make([]byte, isize)
_, err := io.ReadFull(p.trans, buf)
return buf, NewTProtocolException(err)
}
func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TBinaryProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
}
func (p *TBinaryProtocol) Transport() TTransport {
return p.origTransport
}
func (p *TBinaryProtocol) readAll(buf []byte) error {
_, err := io.ReadFull(p.reader, buf)
return NewTProtocolException(err)
}
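// readLimit caps the scratch buffer used by readStringBody below so a corrupt
// or hostile length prefix cannot force a huge up-front allocation; longer
// strings are accumulated in readLimit-sized chunks instead.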
const readLimit = 32768
func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
if size < 0 {
return "", nil
}
var (
buf bytes.Buffer
e error
b []byte
)
switch {
case int(size) <= len(p.buffer):
b = p.buffer[:size] // avoids allocation for small reads
case int(size) < readLimit:
b = make([]byte, size)
default:
b = make([]byte, readLimit)
}
for size > 0 {
_, e = io.ReadFull(p.trans, b)
buf.Write(b)
if e != nil {
break
}
size -= readLimit
if size < readLimit && size > 0 {
b = b[:size]
}
}
return buf.String(), NewTProtocolException(e)
}

View File

@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bufio"
"context"
)
type TBufferedTransportFactory struct {
size int
}
type TBufferedTransport struct {
bufio.ReadWriter
tp TTransport
}
func (p *TBufferedTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
return NewTBufferedTransport(trans, p.size), nil
}
func NewTBufferedTransportFactory(bufferSize int) *TBufferedTransportFactory {
return &TBufferedTransportFactory{size: bufferSize}
}
func NewTBufferedTransport(trans TTransport, bufferSize int) *TBufferedTransport {
return &TBufferedTransport{
ReadWriter: bufio.ReadWriter{
Reader: bufio.NewReaderSize(trans, bufferSize),
Writer: bufio.NewWriterSize(trans, bufferSize),
},
tp: trans,
}
}
func (p *TBufferedTransport) IsOpen() bool {
return p.tp.IsOpen()
}
func (p *TBufferedTransport) Open() (err error) {
return p.tp.Open()
}
func (p *TBufferedTransport) Close() (err error) {
return p.tp.Close()
}
func (p *TBufferedTransport) Read(b []byte) (int, error) {
n, err := p.ReadWriter.Read(b)
if err != nil {
p.ReadWriter.Reader.Reset(p.tp)
}
return n, err
}
func (p *TBufferedTransport) Write(b []byte) (int, error) {
n, err := p.ReadWriter.Write(b)
if err != nil {
p.ReadWriter.Writer.Reset(p.tp)
}
return n, err
}
func (p *TBufferedTransport) Flush(ctx context.Context) error {
if err := p.ReadWriter.Flush(); err != nil {
p.ReadWriter.Writer.Reset(p.tp)
return err
}
return p.tp.Flush(ctx)
}
func (p *TBufferedTransport) RemainingBytes() (numBytes uint64) {
return p.tp.RemainingBytes()
}

View File

@ -0,0 +1,85 @@
package thrift
import (
"context"
"fmt"
)
type TClient interface {
Call(ctx context.Context, method string, args, result TStruct) error
}
type TStandardClient struct {
seqId int32
iprot, oprot TProtocol
}
// TStandardClient implements TClient, and uses the standard message format for Thrift.
// It is not safe for concurrent use.
func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient {
return &TStandardClient{
iprot: inputProtocol,
oprot: outputProtocol,
}
}
func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil {
return err
}
if err := args.Write(oprot); err != nil {
return err
}
if err := oprot.WriteMessageEnd(); err != nil {
return err
}
return oprot.Flush(ctx)
}
func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, result TStruct) error {
rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin()
if err != nil {
return err
}
if method != rMethod {
return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method))
} else if seqId != rSeqId {
return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
} else if rTypeId == EXCEPTION {
var exception tApplicationException
if err := exception.Read(iprot); err != nil {
return err
}
if err := iprot.ReadMessageEnd(); err != nil {
return err
}
return &exception
} else if rTypeId != REPLY {
return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
}
if err := result.Read(iprot); err != nil {
return err
}
return iprot.ReadMessageEnd()
}
func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error {
p.seqId++
seqId := p.seqId
if err := p.Send(ctx, p.oprot, seqId, method, args); err != nil {
return err
}
// A nil result means the method is oneway; there is no response to read.
if result == nil {
return nil
}
return p.Recv(p.iprot, seqId, method, result)
}
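// Usage sketch: generated service clients drive Call; wired up by hand it
// looks like the following (MyArgs/MyResult stand in for compiler-generated
// TStruct types and are not defined here):
//
//	trans := NewTMemoryBuffer() // any TTransport
//	prot := NewTBinaryProtocolTransport(trans)
//	client := NewTStandardClient(prot, prot)
//	err := client.Call(context.Background(), "myMethod", &MyArgs{}, &MyResult{})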

View File

@ -0,0 +1,810 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"encoding/binary"
"fmt"
"io"
"math"
)
const (
COMPACT_PROTOCOL_ID = 0x082
COMPACT_VERSION = 1
COMPACT_VERSION_MASK = 0x1f
COMPACT_TYPE_MASK = 0x0E0
COMPACT_TYPE_BITS = 0x07
COMPACT_TYPE_SHIFT_AMOUNT = 5
)
type tCompactType byte
const (
COMPACT_BOOLEAN_TRUE = 0x01
COMPACT_BOOLEAN_FALSE = 0x02
COMPACT_BYTE = 0x03
COMPACT_I16 = 0x04
COMPACT_I32 = 0x05
COMPACT_I64 = 0x06
COMPACT_DOUBLE = 0x07
COMPACT_BINARY = 0x08
COMPACT_LIST = 0x09
COMPACT_SET = 0x0A
COMPACT_MAP = 0x0B
COMPACT_STRUCT = 0x0C
)
var (
ttypeToCompactType map[TType]tCompactType
)
func init() {
ttypeToCompactType = map[TType]tCompactType{
STOP: STOP,
BOOL: COMPACT_BOOLEAN_TRUE,
BYTE: COMPACT_BYTE,
I16: COMPACT_I16,
I32: COMPACT_I32,
I64: COMPACT_I64,
DOUBLE: COMPACT_DOUBLE,
STRING: COMPACT_BINARY,
LIST: COMPACT_LIST,
SET: COMPACT_SET,
MAP: COMPACT_MAP,
STRUCT: COMPACT_STRUCT,
}
}
type TCompactProtocolFactory struct{}
func NewTCompactProtocolFactory() *TCompactProtocolFactory {
return &TCompactProtocolFactory{}
}
func (p *TCompactProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return NewTCompactProtocol(trans)
}
type TCompactProtocol struct {
trans TRichTransport
origTransport TTransport
// Used to keep track of the last field for the current and previous structs,
// so we can do the delta stuff.
lastField []int
lastFieldId int
// If we encounter a boolean field begin, save the TField here so it can
// have the value incorporated.
booleanFieldName string
booleanFieldId int16
booleanFieldPending bool
// If we read a field header, and it's a boolean field, save the boolean
// value here so that readBool can use it.
boolValue bool
boolValueIsNotNull bool
buffer [64]byte
}
// Create a TCompactProtocol given a TTransport
func NewTCompactProtocol(trans TTransport) *TCompactProtocol {
p := &TCompactProtocol{origTransport: trans, lastField: []int{}}
if et, ok := trans.(TRichTransport); ok {
p.trans = et
} else {
p.trans = NewTRichTransport(trans)
}
return p
}
//
// Public Writing methods.
//
// Write a message header to the wire. Compact Protocol messages contain the
// protocol version so we can migrate forwards in the future if need be.
func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
err := p.writeByteDirect(COMPACT_PROTOCOL_ID)
if err != nil {
return NewTProtocolException(err)
}
err = p.writeByteDirect((COMPACT_VERSION & COMPACT_VERSION_MASK) | ((byte(typeId) << COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_MASK))
if err != nil {
return NewTProtocolException(err)
}
_, err = p.writeVarint32(seqid)
if err != nil {
return NewTProtocolException(err)
}
e := p.WriteString(name)
return e
}
func (p *TCompactProtocol) WriteMessageEnd() error { return nil }
// Write a struct begin. This doesn't actually put anything on the wire. We
// use it as an opportunity to put special placeholder markers on the field
// stack so we can get the field id deltas correct.
func (p *TCompactProtocol) WriteStructBegin(name string) error {
p.lastField = append(p.lastField, p.lastFieldId)
p.lastFieldId = 0
return nil
}
// Write a struct end. This doesn't actually put anything on the wire. We use
// this as an opportunity to pop the last field from the current struct off
// of the field stack.
func (p *TCompactProtocol) WriteStructEnd() error {
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
}
func (p *TCompactProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
if typeId == BOOL {
// we want to possibly include the value, so we'll wait.
p.booleanFieldName, p.booleanFieldId, p.booleanFieldPending = name, id, true
return nil
}
_, err := p.writeFieldBeginInternal(name, typeId, id, 0xFF)
return NewTProtocolException(err)
}
// The workhorse of writeFieldBegin. It has the option of doing a
// 'type override' of the type header. This is used specifically in the
// boolean field case.
func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id int16, typeOverride byte) (int, error) {
// if there's a type override, use that.
var typeToWrite byte
if typeOverride == 0xFF {
typeToWrite = byte(p.getCompactType(typeId))
} else {
typeToWrite = typeOverride
}
// check if we can use delta encoding for the field id
fieldId := int(id)
written := 0
if fieldId > p.lastFieldId && fieldId-p.lastFieldId <= 15 {
// write them together
err := p.writeByteDirect(byte((fieldId-p.lastFieldId)<<4) | typeToWrite)
if err != nil {
return 0, err
}
} else {
// write them separate
err := p.writeByteDirect(typeToWrite)
if err != nil {
return 0, err
}
err = p.WriteI16(id)
written = 1 + 2
if err != nil {
return 0, err
}
}
p.lastFieldId = fieldId
return written, nil
}
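// Worked example of the delta encoding above: if the previous field id was 1
// and the next field has id 5 with compact type COMPACT_I32 (0x05), the delta
// is 4, so the header is the single byte (4<<4)|0x05 = 0x45. A non-positive
// delta, or one greater than 15, falls back to a full type byte followed by a
// zigzag-varint i16 field id.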
func (p *TCompactProtocol) WriteFieldEnd() error { return nil }
func (p *TCompactProtocol) WriteFieldStop() error {
err := p.writeByteDirect(STOP)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
if size == 0 {
err := p.writeByteDirect(0)
return NewTProtocolException(err)
}
_, err := p.writeVarint32(int32(size))
if err != nil {
return NewTProtocolException(err)
}
err = p.writeByteDirect(byte(p.getCompactType(keyType))<<4 | byte(p.getCompactType(valueType)))
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteMapEnd() error { return nil }
// Write a list header.
func (p *TCompactProtocol) WriteListBegin(elemType TType, size int) error {
_, err := p.writeCollectionBegin(elemType, size)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteListEnd() error { return nil }
// Write a set header.
func (p *TCompactProtocol) WriteSetBegin(elemType TType, size int) error {
_, err := p.writeCollectionBegin(elemType, size)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteSetEnd() error { return nil }
func (p *TCompactProtocol) WriteBool(value bool) error {
v := byte(COMPACT_BOOLEAN_FALSE)
if value {
v = byte(COMPACT_BOOLEAN_TRUE)
}
if p.booleanFieldPending {
// we haven't written the field header yet
_, err := p.writeFieldBeginInternal(p.booleanFieldName, BOOL, p.booleanFieldId, v)
p.booleanFieldPending = false
return NewTProtocolException(err)
}
// we're not part of a field, so just write the value.
err := p.writeByteDirect(v)
return NewTProtocolException(err)
}
// Write a byte. Nothing to see here!
func (p *TCompactProtocol) WriteByte(value int8) error {
err := p.writeByteDirect(byte(value))
return NewTProtocolException(err)
}
// Write an I16 as a zigzag varint.
func (p *TCompactProtocol) WriteI16(value int16) error {
_, err := p.writeVarint32(p.int32ToZigzag(int32(value)))
return NewTProtocolException(err)
}
// Write an i32 as a zigzag varint.
func (p *TCompactProtocol) WriteI32(value int32) error {
_, err := p.writeVarint32(p.int32ToZigzag(value))
return NewTProtocolException(err)
}
// Write an i64 as a zigzag varint.
func (p *TCompactProtocol) WriteI64(value int64) error {
_, err := p.writeVarint64(p.int64ToZigzag(value))
return NewTProtocolException(err)
}
// Write a double to the wire as 8 bytes.
func (p *TCompactProtocol) WriteDouble(value float64) error {
buf := p.buffer[0:8]
binary.LittleEndian.PutUint64(buf, math.Float64bits(value))
_, err := p.trans.Write(buf)
return NewTProtocolException(err)
}
// Write a string to the wire with a varint size preceding.
func (p *TCompactProtocol) WriteString(value string) error {
_, e := p.writeVarint32(int32(len(value)))
if e != nil {
return NewTProtocolException(e)
}
_, e = p.trans.WriteString(value)
return e
}
// Write a byte array, using a varint for the size.
func (p *TCompactProtocol) WriteBinary(bin []byte) error {
_, e := p.writeVarint32(int32(len(bin)))
if e != nil {
return NewTProtocolException(e)
}
if len(bin) > 0 {
_, e = p.trans.Write(bin)
return NewTProtocolException(e)
}
return nil
}
//
// Reading methods.
//
// Read a message header.
func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
protocolId, err := p.readByteDirect()
if err != nil {
return
}
if protocolId != COMPACT_PROTOCOL_ID {
e := fmt.Errorf("Expected protocol id %02x but got %02x", COMPACT_PROTOCOL_ID, protocolId)
return "", typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, e)
}
versionAndType, err := p.readByteDirect()
if err != nil {
return
}
version := versionAndType & COMPACT_VERSION_MASK
typeId = TMessageType((versionAndType >> COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_BITS)
if version != COMPACT_VERSION {
e := fmt.Errorf("Expected version %02x but got %02x", COMPACT_VERSION, version)
err = NewTProtocolExceptionWithType(BAD_VERSION, e)
return
}
seqId, e := p.readVarint32()
if e != nil {
err = NewTProtocolException(e)
return
}
name, err = p.ReadString()
return
}
func (p *TCompactProtocol) ReadMessageEnd() error { return nil }
// Read a struct begin. There's nothing on the wire for this, but it is our
// opportunity to push a new struct begin marker onto the field stack.
func (p *TCompactProtocol) ReadStructBegin() (name string, err error) {
p.lastField = append(p.lastField, p.lastFieldId)
p.lastFieldId = 0
return
}
// Doesn't actually consume any wire data, just removes the last field for
// this struct from the field stack.
func (p *TCompactProtocol) ReadStructEnd() error {
// consume the last field we read off the wire.
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
}
// Read a field header off the wire.
func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
t, err := p.readByteDirect()
if err != nil {
return
}
// if it's a stop, then we can return immediately, as the struct is over.
if (t & 0x0f) == STOP {
return "", STOP, 0, nil
}
// mask off the 4 MSB of the type header. it could contain a field id delta.
modifier := int16((t & 0xf0) >> 4)
if modifier == 0 {
// not a delta. look ahead for the zigzag varint field id.
id, err = p.ReadI16()
if err != nil {
return
}
} else {
// has a delta. add the delta to the last read field id.
id = int16(p.lastFieldId) + modifier
}
typeId, e := p.getTType(tCompactType(t & 0x0f))
if e != nil {
err = NewTProtocolException(e)
return
}
// if this happens to be a boolean field, the value is encoded in the type
if p.isBoolType(t) {
// save the boolean value in a special instance variable.
p.boolValue = (byte(t)&0x0f == COMPACT_BOOLEAN_TRUE)
p.boolValueIsNotNull = true
}
// push the new field onto the field stack so we can keep the deltas going.
p.lastFieldId = int(id)
return
}
func (p *TCompactProtocol) ReadFieldEnd() error { return nil }
// Read a map header off the wire. If the size is zero, skip reading the key
// and value type. This means that 0-length maps will yield TMaps without the
// "correct" types.
func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
size32, e := p.readVarint32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
keyAndValueType := byte(STOP)
if size != 0 {
keyAndValueType, err = p.readByteDirect()
if err != nil {
return
}
}
keyType, _ = p.getTType(tCompactType(keyAndValueType >> 4))
valueType, _ = p.getTType(tCompactType(keyAndValueType & 0xf))
return
}
func (p *TCompactProtocol) ReadMapEnd() error { return nil }
// Read a list header off the wire. If the list size is 0-14, the size will
// be packed into the element type header. If it's a longer list, the 4 MSB
// of the element type header will be 0xF, and a varint will follow with the
// true size.
func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error) {
sizeAndType, err := p.readByteDirect()
if err != nil {
return
}
size = int((sizeAndType >> 4) & 0x0f)
if size == 15 {
size2, e := p.readVarint32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size2 < 0 {
err = invalidDataLength
return
}
size = int(size2)
}
elemType, e := p.getTType(tCompactType(sizeAndType))
if e != nil {
err = NewTProtocolException(e)
return
}
return
}
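// Worked example of the list header read above: the byte 0x35 means size 3
// with compact element type 0x05 (COMPACT_I32), while 0xF5 means the size
// nibble is 15, so the true size follows as a separate varint.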
func (p *TCompactProtocol) ReadListEnd() error { return nil }
// Read a set header off the wire. If the set size is 0-14, the size will
// be packed into the element type header. If it's a longer set, the 4 MSB
// of the element type header will be 0xF, and a varint will follow with the
// true size.
func (p *TCompactProtocol) ReadSetBegin() (elemType TType, size int, err error) {
return p.ReadListBegin()
}
func (p *TCompactProtocol) ReadSetEnd() error { return nil }
// Read a boolean off the wire. If this is a boolean field, the value should
// already have been read during readFieldBegin, so we'll just consume the
// pre-stored value. Otherwise, read a byte.
func (p *TCompactProtocol) ReadBool() (value bool, err error) {
if p.boolValueIsNotNull {
p.boolValueIsNotNull = false
return p.boolValue, nil
}
v, err := p.readByteDirect()
return v == COMPACT_BOOLEAN_TRUE, err
}
// Read a single byte off the wire. Nothing interesting here.
func (p *TCompactProtocol) ReadByte() (int8, error) {
v, err := p.readByteDirect()
if err != nil {
return 0, NewTProtocolException(err)
}
return int8(v), err
}
// Read an i16 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI16() (value int16, err error) {
v, err := p.ReadI32()
return int16(v), err
}
// Read an i32 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI32() (value int32, err error) {
v, e := p.readVarint32()
if e != nil {
return 0, NewTProtocolException(e)
}
value = p.zigzagToInt32(v)
return value, nil
}
// Read an i64 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI64() (value int64, err error) {
v, e := p.readVarint64()
if e != nil {
return 0, NewTProtocolException(e)
}
value = p.zigzagToInt64(v)
return value, nil
}
// No magic here - just read a double off the wire.
func (p *TCompactProtocol) ReadDouble() (value float64, err error) {
longBits := p.buffer[0:8]
_, e := io.ReadFull(p.trans, longBits)
if e != nil {
return 0.0, NewTProtocolException(e)
}
return math.Float64frombits(p.bytesToUint64(longBits)), nil
}
// Reads a []byte (via readBinary), and then UTF-8 decodes it.
func (p *TCompactProtocol) ReadString() (value string, err error) {
length, e := p.readVarint32()
if e != nil {
return "", NewTProtocolException(e)
}
if length < 0 {
return "", invalidDataLength
}
if length == 0 {
return "", nil
}
var buf []byte
if length <= int32(len(p.buffer)) {
buf = p.buffer[0:length]
} else {
buf = make([]byte, length)
}
_, e = io.ReadFull(p.trans, buf)
return string(buf), NewTProtocolException(e)
}
// Read a []byte from the wire.
func (p *TCompactProtocol) ReadBinary() (value []byte, err error) {
length, e := p.readVarint32()
if e != nil {
return nil, NewTProtocolException(e)
}
if length == 0 {
return []byte{}, nil
}
if length < 0 {
return nil, invalidDataLength
}
buf := make([]byte, length)
_, e = io.ReadFull(p.trans, buf)
return buf, NewTProtocolException(e)
}
func (p *TCompactProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TCompactProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
}
func (p *TCompactProtocol) Transport() TTransport {
return p.origTransport
}
//
// Internal writing methods
//
// Abstract method for writing the start of lists and sets. List and sets on
// the wire differ only by the type indicator.
func (p *TCompactProtocol) writeCollectionBegin(elemType TType, size int) (int, error) {
if size <= 14 {
return 1, p.writeByteDirect(byte(int32(size<<4) | int32(p.getCompactType(elemType))))
}
err := p.writeByteDirect(0xf0 | byte(p.getCompactType(elemType)))
if err != nil {
return 0, err
}
m, err := p.writeVarint32(int32(size))
return 1 + m, err
}
// Write an i32 as a varint. Results in 1-5 bytes on the wire.
// TODO(pomack): make a permanent buffer like writeVarint64?
func (p *TCompactProtocol) writeVarint32(n int32) (int, error) {
i32buf := p.buffer[0:5]
idx := 0
for {
if (n & ^0x7F) == 0 {
i32buf[idx] = byte(n)
idx++
break
} else {
i32buf[idx] = byte((n & 0x7F) | 0x80)
idx++
u := uint32(n)
n = int32(u >> 7)
}
}
return p.trans.Write(i32buf[0:idx])
}
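// Worked example (illustrative): encoding int32(300) with the loop above.
// 300 = 0b1_0010_1100, so the low 7 bits go out first with the continuation
// bit set, then the remainder:
//
// b0 := byte((300 & 0x7F) | 0x80) // 0xAC
// b1 := byte(300 >> 7) // 0x02; high bit clear marks the last byte
//
// Wire bytes: 0xAC 0x02.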
// Write an i64 as a varint. Results in 1-10 bytes on the wire.
func (p *TCompactProtocol) writeVarint64(n int64) (int, error) {
varint64out := p.buffer[0:10]
idx := 0
for {
if (n & ^0x7F) == 0 {
varint64out[idx] = byte(n)
idx++
break
} else {
varint64out[idx] = byte((n & 0x7F) | 0x80)
idx++
u := uint64(n)
n = int64(u >> 7)
}
}
return p.trans.Write(varint64out[0:idx])
}
// Convert l into a zigzag long. This allows negative numbers to be
// represented compactly as a varint.
func (p *TCompactProtocol) int64ToZigzag(l int64) int64 {
return (l << 1) ^ (l >> 63)
}
// Convert l into a zigzag long. This allows negative numbers to be
// represented compactly as a varint.
func (p *TCompactProtocol) int32ToZigzag(n int32) int32 {
return (n << 1) ^ (n >> 31)
}
func (p *TCompactProtocol) fixedUint64ToBytes(n uint64, buf []byte) {
binary.LittleEndian.PutUint64(buf, n)
}
func (p *TCompactProtocol) fixedInt64ToBytes(n int64, buf []byte) {
binary.LittleEndian.PutUint64(buf, uint64(n))
}
// Writes a byte without any possibility of all that field header nonsense.
// Used internally by other writing methods that know they need to write a byte.
func (p *TCompactProtocol) writeByteDirect(b byte) error {
return p.trans.WriteByte(b)
}
// Writes a byte without any possibility of all that field header nonsense.
func (p *TCompactProtocol) writeIntAsByteDirect(n int) (int, error) {
return 1, p.writeByteDirect(byte(n))
}
//
// Internal reading methods
//
// Read an i32 from the wire as a varint. The MSB of each byte is set
// if there is another byte to follow. This can read up to 5 bytes.
func (p *TCompactProtocol) readVarint32() (int32, error) {
// if the wire contains the right stuff, this will just truncate the i64 we
// read and get us the right sign.
v, err := p.readVarint64()
return int32(v), err
}
// Read an i64 from the wire as a proper varint. The MSB of each byte is set
// if there is another byte to follow. This can read up to 10 bytes.
func (p *TCompactProtocol) readVarint64() (int64, error) {
shift := uint(0)
result := int64(0)
for {
b, err := p.readByteDirect()
if err != nil {
return 0, err
}
result |= int64(b&0x7f) << shift
if (b & 0x80) != 0x80 {
break
}
shift += 7
}
return result, nil
}
// Read a byte, unlike ReadByte that reads Thrift-byte that is i8.
func (p *TCompactProtocol) readByteDirect() (byte, error) {
return p.trans.ReadByte()
}
//
// encoding helpers
//
// Convert from zigzag int to int.
func (p *TCompactProtocol) zigzagToInt32(n int32) int32 {
u := uint32(n)
return int32(u>>1) ^ -(n & 1)
}
// Convert from zigzag long to long.
func (p *TCompactProtocol) zigzagToInt64(n int64) int64 {
u := uint64(n)
return int64(u>>1) ^ -(n & 1)
}
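// Worked example (illustrative): zigzag interleaves signed values so small
// magnitudes encode to short varints: 0->0, -1->1, 1->2, -2->3, 2->4.
// For n = int32(-3): (n << 1) ^ (n >> 31) = -6 ^ -1 = 5, and decoding gives
// int32(5>>1) ^ -(5 & 1) = 2 ^ -1 = -3, recovering the input.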
// Note that it's important that the mask bytes are long literals,
// otherwise they'll default to ints, and when you shift an int left 56 bits,
// you just get a messed up int.
func (p *TCompactProtocol) bytesToInt64(b []byte) int64 {
return int64(binary.LittleEndian.Uint64(b))
}
// Note that it's important that the mask bytes are long literals,
// otherwise they'll default to ints, and when you shift an int left 56 bits,
// you just get a messed up int.
func (p *TCompactProtocol) bytesToUint64(b []byte) uint64 {
return binary.LittleEndian.Uint64(b)
}
//
// type testing and converting
//
func (p *TCompactProtocol) isBoolType(b byte) bool {
return (b&0x0f) == COMPACT_BOOLEAN_TRUE || (b&0x0f) == COMPACT_BOOLEAN_FALSE
}
// Given a tCompactType constant, convert it to its corresponding
// TType value.
func (p *TCompactProtocol) getTType(t tCompactType) (TType, error) {
switch byte(t) & 0x0f {
case STOP:
return STOP, nil
case COMPACT_BOOLEAN_FALSE, COMPACT_BOOLEAN_TRUE:
return BOOL, nil
case COMPACT_BYTE:
return BYTE, nil
case COMPACT_I16:
return I16, nil
case COMPACT_I32:
return I32, nil
case COMPACT_I64:
return I64, nil
case COMPACT_DOUBLE:
return DOUBLE, nil
case COMPACT_BINARY:
return STRING, nil
case COMPACT_LIST:
return LIST, nil
case COMPACT_SET:
return SET, nil
case COMPACT_MAP:
return MAP, nil
case COMPACT_STRUCT:
return STRUCT, nil
}
return STOP, TException(fmt.Errorf("don't know what type: %v", t&0x0f))
}
// Given a TType value, find the appropriate TCompactProtocol.Types constant.
func (p *TCompactProtocol) getCompactType(t TType) tCompactType {
return ttypeToCompactType[t]
}


@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "context"
var defaultCtx = context.Background()


@ -0,0 +1,270 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"log"
)
type TDebugProtocol struct {
Delegate TProtocol
LogPrefix string
}
type TDebugProtocolFactory struct {
Underlying TProtocolFactory
LogPrefix string
}
func NewTDebugProtocolFactory(underlying TProtocolFactory, logPrefix string) *TDebugProtocolFactory {
return &TDebugProtocolFactory{
Underlying: underlying,
LogPrefix: logPrefix,
}
}
func (t *TDebugProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return &TDebugProtocol{
Delegate: t.Underlying.GetProtocol(trans),
LogPrefix: t.LogPrefix,
}
}
func (tdp *TDebugProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
err := tdp.Delegate.WriteMessageBegin(name, typeId, seqid)
log.Printf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err)
return err
}
func (tdp *TDebugProtocol) WriteMessageEnd() error {
err := tdp.Delegate.WriteMessageEnd()
log.Printf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteStructBegin(name string) error {
err := tdp.Delegate.WriteStructBegin(name)
log.Printf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err)
return err
}
func (tdp *TDebugProtocol) WriteStructEnd() error {
err := tdp.Delegate.WriteStructEnd()
log.Printf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
err := tdp.Delegate.WriteFieldBegin(name, typeId, id)
log.Printf("%sWriteFieldBegin(name=%#v, typeId=%#v, id%#v) => %#v", tdp.LogPrefix, name, typeId, id, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldEnd() error {
err := tdp.Delegate.WriteFieldEnd()
log.Printf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldStop() error {
err := tdp.Delegate.WriteFieldStop()
log.Printf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
err := tdp.Delegate.WriteMapBegin(keyType, valueType, size)
log.Printf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteMapEnd() error {
err := tdp.Delegate.WriteMapEnd()
log.Printf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteListBegin(elemType TType, size int) error {
err := tdp.Delegate.WriteListBegin(elemType, size)
log.Printf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteListEnd() error {
err := tdp.Delegate.WriteListEnd()
log.Printf("%sWriteListEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteSetBegin(elemType TType, size int) error {
err := tdp.Delegate.WriteSetBegin(elemType, size)
log.Printf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteSetEnd() error {
err := tdp.Delegate.WriteSetEnd()
log.Printf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteBool(value bool) error {
err := tdp.Delegate.WriteBool(value)
log.Printf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteByte(value int8) error {
err := tdp.Delegate.WriteByte(value)
log.Printf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI16(value int16) error {
err := tdp.Delegate.WriteI16(value)
log.Printf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI32(value int32) error {
err := tdp.Delegate.WriteI32(value)
log.Printf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI64(value int64) error {
err := tdp.Delegate.WriteI64(value)
log.Printf("%sWriteI64(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteDouble(value float64) error {
err := tdp.Delegate.WriteDouble(value)
log.Printf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteString(value string) error {
err := tdp.Delegate.WriteString(value)
log.Printf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteBinary(value []byte) error {
err := tdp.Delegate.WriteBinary(value)
log.Printf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin()
log.Printf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err)
return
}
func (tdp *TDebugProtocol) ReadMessageEnd() (err error) {
err = tdp.Delegate.ReadMessageEnd()
log.Printf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadStructBegin() (name string, err error) {
name, err = tdp.Delegate.ReadStructBegin()
log.Printf("%sReadStructBegin() (name%#v, err=%#v)", tdp.LogPrefix, name, err)
return
}
func (tdp *TDebugProtocol) ReadStructEnd() (err error) {
err = tdp.Delegate.ReadStructEnd()
log.Printf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
name, typeId, id, err = tdp.Delegate.ReadFieldBegin()
log.Printf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err)
return
}
func (tdp *TDebugProtocol) ReadFieldEnd() (err error) {
err = tdp.Delegate.ReadFieldEnd()
log.Printf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
keyType, valueType, size, err = tdp.Delegate.ReadMapBegin()
log.Printf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err)
return
}
func (tdp *TDebugProtocol) ReadMapEnd() (err error) {
err = tdp.Delegate.ReadMapEnd()
log.Printf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadListBegin() (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadListBegin()
log.Printf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
return
}
func (tdp *TDebugProtocol) ReadListEnd() (err error) {
err = tdp.Delegate.ReadListEnd()
log.Printf("%sReadListEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadSetBegin() (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadSetBegin()
log.Printf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
return
}
func (tdp *TDebugProtocol) ReadSetEnd() (err error) {
err = tdp.Delegate.ReadSetEnd()
log.Printf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadBool() (value bool, err error) {
value, err = tdp.Delegate.ReadBool()
log.Printf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadByte() (value int8, err error) {
value, err = tdp.Delegate.ReadByte()
log.Printf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI16() (value int16, err error) {
value, err = tdp.Delegate.ReadI16()
log.Printf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI32() (value int32, err error) {
value, err = tdp.Delegate.ReadI32()
log.Printf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI64() (value int64, err error) {
value, err = tdp.Delegate.ReadI64()
log.Printf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadDouble() (value float64, err error) {
value, err = tdp.Delegate.ReadDouble()
log.Printf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadString() (value string, err error) {
value, err = tdp.Delegate.ReadString()
log.Printf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadBinary() (value []byte, err error) {
value, err = tdp.Delegate.ReadBinary()
log.Printf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) {
err = tdp.Delegate.Skip(fieldType)
log.Printf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err)
return
}
func (tdp *TDebugProtocol) Flush(ctx context.Context) (err error) {
err = tdp.Delegate.Flush(ctx)
log.Printf("%sFlush() (err=%#v)", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) Transport() TTransport {
return tdp.Delegate.Transport()
}
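// Usage sketch (illustrative; the factory, transport, and prefix choices are
// arbitrary): wrap any protocol factory so every call is logged before being
// delegated.
//
// factory := NewTDebugProtocolFactory(NewTBinaryProtocolFactoryDefault(), "client: ")
// proto := factory.GetProtocol(NewTMemoryBuffer())
// proto.WriteI32(42) // logs: client: WriteI32(value=42) => <nil>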


@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
type TDeserializer struct {
Transport TTransport
Protocol TProtocol
}
func NewTDeserializer() *TDeserializer {
transport := NewTMemoryBufferLen(1024)
protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport)
return &TDeserializer{
transport,
protocol}
}
func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) {
if _, err = t.Transport.Write([]byte(s)); err != nil {
return
}
if err = msg.Read(t.Protocol); err != nil {
return
}
return
}
func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) {
if _, err = t.Transport.Write(b); err != nil {
return
}
if err = msg.Read(t.Protocol); err != nil {
return
}
return
}
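// Usage sketch (MyStruct stands in for any generated type implementing
// TStruct, and payload for its serialized bytes; both are assumed for
// illustration only):
//
// d := NewTDeserializer()
// var msg MyStruct
// if err := d.Read(&msg, payload); err != nil {
// // payload was not a valid binary-protocol encoding of MyStruct
// }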


@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"errors"
)
// Generic Thrift exception
type TException interface {
error
}
// Prepends additional information to an error without losing the Thrift exception interface
func PrependError(prepend string, err error) error {
if t, ok := err.(TTransportException); ok {
return NewTTransportException(t.TypeId(), prepend+t.Error())
}
if t, ok := err.(TProtocolException); ok {
return NewTProtocolExceptionWithType(t.TypeId(), errors.New(prepend+err.Error()))
}
if t, ok := err.(TApplicationException); ok {
return NewTApplicationException(t.TypeId(), prepend+t.Error())
}
return errors.New(prepend + err.Error())
}
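// Usage sketch (illustrative): the concrete exception type survives, so
// callers can still inspect it after context is added.
//
// err := NewTTransportException(NOT_OPEN, "connection closed")
// err = PrependError("reading frame header: ", err)
// // err is still a TTransportException with TypeId() == NOT_OPEN and
// // message "reading frame header: connection closed"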


@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "strconv"
// Helper class that encapsulates field metadata.
type field struct {
name string
typeId TType
id int
}
func newField(n string, t TType, i int) *field {
return &field{name: n, typeId: t, id: i}
}
func (p *field) Name() string {
if p == nil {
return ""
}
return p.name
}
func (p *field) TypeId() TType {
if p == nil {
return TType(VOID)
}
return p.typeId
}
func (p *field) Id() int {
if p == nil {
return -1
}
return p.id
}
func (p *field) String() string {
if p == nil {
return "<nil>"
}
// strconv is required here: string(int) would yield a rune, not the digits
return "<TField name:'" + p.name + "' type:" + strconv.Itoa(int(p.typeId)) + " field-id:" + strconv.Itoa(p.id) + ">"
}
var ANONYMOUS_FIELD *field
type fieldSlice []field
func (p fieldSlice) Len() int {
return len(p)
}
func (p fieldSlice) Less(i, j int) bool {
return p[i].Id() < p[j].Id()
}
func (p fieldSlice) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}
func init() {
ANONYMOUS_FIELD = newField("", STOP, 0)
}


@ -0,0 +1,173 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
)
const DEFAULT_MAX_LENGTH = 16384000
type TFramedTransport struct {
transport TTransport
buf bytes.Buffer
reader *bufio.Reader
frameSize uint32 // remaining bytes in the current frame; 0 means read the next frame header
buffer [4]byte
maxLength uint32
}
type tFramedTransportFactory struct {
factory TTransportFactory
maxLength uint32
}
func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory {
return &tFramedTransportFactory{factory: factory, maxLength: DEFAULT_MAX_LENGTH}
}
func NewTFramedTransportFactoryMaxLength(factory TTransportFactory, maxLength uint32) TTransportFactory {
return &tFramedTransportFactory{factory: factory, maxLength: maxLength}
}
func (p *tFramedTransportFactory) GetTransport(base TTransport) (TTransport, error) {
tt, err := p.factory.GetTransport(base)
if err != nil {
return nil, err
}
return NewTFramedTransportMaxLength(tt, p.maxLength), nil
}
func NewTFramedTransport(transport TTransport) *TFramedTransport {
return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: DEFAULT_MAX_LENGTH}
}
func NewTFramedTransportMaxLength(transport TTransport, maxLength uint32) *TFramedTransport {
return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: maxLength}
}
func (p *TFramedTransport) Open() error {
return p.transport.Open()
}
func (p *TFramedTransport) IsOpen() bool {
return p.transport.IsOpen()
}
func (p *TFramedTransport) Close() error {
return p.transport.Close()
}
func (p *TFramedTransport) Read(buf []byte) (l int, err error) {
if p.frameSize == 0 {
p.frameSize, err = p.readFrameHeader()
if err != nil {
return
}
}
if p.frameSize < uint32(len(buf)) {
frameSize := p.frameSize
tmp := make([]byte, p.frameSize)
l, err = p.Read(tmp)
copy(buf, tmp)
if err == nil {
err = NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", frameSize, len(buf)))
return
}
}
got, err := p.reader.Read(buf)
// sanity check: frameSize is unsigned, so a plain < 0 test can never fire;
// guard against underflow before subtracting instead
if uint32(got) > p.frameSize {
return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Negative frame size")
}
p.frameSize = p.frameSize - uint32(got)
return got, NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) ReadByte() (c byte, err error) {
if p.frameSize == 0 {
p.frameSize, err = p.readFrameHeader()
if err != nil {
return
}
}
if p.frameSize < 1 {
return 0, NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", p.frameSize, 1))
}
c, err = p.reader.ReadByte()
if err == nil {
p.frameSize--
}
return
}
func (p *TFramedTransport) Write(buf []byte) (int, error) {
n, err := p.buf.Write(buf)
return n, NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) WriteByte(c byte) error {
return p.buf.WriteByte(c)
}
func (p *TFramedTransport) WriteString(s string) (n int, err error) {
return p.buf.WriteString(s)
}
func (p *TFramedTransport) Flush(ctx context.Context) error {
size := p.buf.Len()
buf := p.buffer[:4]
binary.BigEndian.PutUint32(buf, uint32(size))
_, err := p.transport.Write(buf)
if err != nil {
p.buf.Truncate(0)
return NewTTransportExceptionFromError(err)
}
if size > 0 {
if n, err := p.buf.WriteTo(p.transport); err != nil {
print("Error while flushing write buffer of size ", size, " to transport, only wrote ", n, " bytes: ", err.Error(), "\n")
p.buf.Truncate(0)
return NewTTransportExceptionFromError(err)
}
}
err = p.transport.Flush(ctx)
return NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) readFrameHeader() (uint32, error) {
buf := p.buffer[:4]
if _, err := io.ReadFull(p.reader, buf); err != nil {
return 0, err
}
size := binary.BigEndian.Uint32(buf)
if size > p.maxLength { // size is unsigned, so only the upper bound needs checking
return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size))
}
return size, nil
}
func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) {
return uint64(p.frameSize)
}
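// Wire-format sketch (illustrative): a frame is a 4-byte big-endian length
// prefix followed by exactly that many payload bytes; readFrameHeader parses
// the prefix and Flush writes it.
//
// payload := []byte("abc")
// frame := make([]byte, 4+len(payload))
// binary.BigEndian.PutUint32(frame[:4], uint32(len(payload)))
// copy(frame[4:], payload) // frame: 00 00 00 03 61 62 63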


@ -0,0 +1,242 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"context"
"io"
"io/ioutil"
"net/http"
"net/url"
"strconv"
)
// Default to using the shared http client. Library users are
// free to change this global client or specify one through
// THttpClientOptions.
var DefaultHttpClient *http.Client = http.DefaultClient
type THttpClient struct {
client *http.Client
response *http.Response
url *url.URL
requestBuffer *bytes.Buffer
header http.Header
nsecConnectTimeout int64
nsecReadTimeout int64
}
type THttpClientTransportFactory struct {
options THttpClientOptions
url string
}
func (p *THttpClientTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*THttpClient)
if ok && t.url != nil {
return NewTHttpClientWithOptions(t.url.String(), p.options)
}
}
return NewTHttpClientWithOptions(p.url, p.options)
}
type THttpClientOptions struct {
// If nil, DefaultHttpClient is used
Client *http.Client
}
func NewTHttpClientTransportFactory(url string) *THttpClientTransportFactory {
return NewTHttpClientTransportFactoryWithOptions(url, THttpClientOptions{})
}
func NewTHttpClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory {
return &THttpClientTransportFactory{url: url, options: options}
}
func NewTHttpClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) {
parsedURL, err := url.Parse(urlstr)
if err != nil {
return nil, err
}
buf := make([]byte, 0, 1024)
client := options.Client
if client == nil {
client = DefaultHttpClient
}
httpHeader := map[string][]string{"Content-Type": {"application/x-thrift"}}
return &THttpClient{client: client, url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: httpHeader}, nil
}
func NewTHttpClient(urlstr string) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, THttpClientOptions{})
}
// Set the HTTP Header for this specific Thrift Transport
// It is important that you first assert the TTransport as a THttpClient type
// like so:
//
// httpTrans := trans.(THttpClient)
// httpTrans.SetHeader("User-Agent","Thrift Client 1.0")
func (p *THttpClient) SetHeader(key string, value string) {
p.header.Add(key, value)
}
// Get the HTTP Header represented by the supplied Header Key for this specific Thrift Transport
// It is important that you first assert the TTransport as a THttpClient type
// like so:
//
// httpTrans := trans.(THttpClient)
// hdrValue := httpTrans.GetHeader("User-Agent")
func (p *THttpClient) GetHeader(key string) string {
return p.header.Get(key)
}
// Deletes the HTTP Header given a Header Key for this specific Thrift Transport
// It is important that you first assert the TTransport as a THttpClient type
// like so:
//
// httpTrans := trans.(THttpClient)
// httpTrans.DelHeader("User-Agent")
func (p *THttpClient) DelHeader(key string) {
p.header.Del(key)
}
func (p *THttpClient) Open() error {
// do nothing
return nil
}
func (p *THttpClient) IsOpen() bool {
return p.response != nil || p.requestBuffer != nil
}
func (p *THttpClient) closeResponse() error {
var err error
if p.response != nil && p.response.Body != nil {
// The docs specify that if keepalive is enabled and the response body is not
// read to completion the connection will never be returned to the pool and
// reused. Errors are being ignored here because if the connection is invalid
// and this fails for some reason, the Close() method will do any remaining
// cleanup.
io.Copy(ioutil.Discard, p.response.Body)
err = p.response.Body.Close()
}
p.response = nil
return err
}
func (p *THttpClient) Close() error {
if p.requestBuffer != nil {
p.requestBuffer.Reset()
p.requestBuffer = nil
}
return p.closeResponse()
}
func (p *THttpClient) Read(buf []byte) (int, error) {
if p.response == nil {
return 0, NewTTransportException(NOT_OPEN, "Response buffer is empty, no request.")
}
n, err := p.response.Body.Read(buf)
if n > 0 && (err == nil || err == io.EOF) {
return n, nil
}
return n, NewTTransportExceptionFromError(err)
}
func (p *THttpClient) ReadByte() (c byte, err error) {
return readByte(p.response.Body)
}
func (p *THttpClient) Write(buf []byte) (int, error) {
n, err := p.requestBuffer.Write(buf)
return n, err
}
func (p *THttpClient) WriteByte(c byte) error {
return p.requestBuffer.WriteByte(c)
}
func (p *THttpClient) WriteString(s string) (n int, err error) {
return p.requestBuffer.WriteString(s)
}
func (p *THttpClient) Flush(ctx context.Context) error {
// Close any previous response body to avoid leaking connections.
p.closeResponse()
req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer)
if err != nil {
return NewTTransportExceptionFromError(err)
}
req.Header = p.header
if ctx != nil {
req = req.WithContext(ctx)
}
response, err := p.client.Do(req)
if err != nil {
return NewTTransportExceptionFromError(err)
}
if response.StatusCode != http.StatusOK {
// Close the response to avoid leaking file descriptors. closeResponse does
// more than just call Close(), so temporarily assign it and reuse the logic.
p.response = response
p.closeResponse()
// TODO(pomack) log bad response
return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "HTTP Response code: "+strconv.Itoa(response.StatusCode))
}
p.response = response
return nil
}
func (p *THttpClient) RemainingBytes() (num_bytes uint64) {
length := p.response.ContentLength
if length >= 0 {
return uint64(length)
}
const maxSize = ^uint64(0)
return maxSize // the truth is, we just don't know unless framed is used
}
// Deprecated: Use NewTHttpClientTransportFactory instead.
func NewTHttpPostClientTransportFactory(url string) *THttpClientTransportFactory {
return NewTHttpClientTransportFactoryWithOptions(url, THttpClientOptions{})
}
// Deprecated: Use NewTHttpClientTransportFactoryWithOptions instead.
func NewTHttpPostClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory {
return NewTHttpClientTransportFactoryWithOptions(url, options)
}
// Deprecated: Use NewTHttpClientWithOptions instead.
func NewTHttpPostClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, options)
}
// Deprecated: Use NewTHttpClient instead.
func NewTHttpPostClient(urlstr string) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, THttpClientOptions{})
}
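// Usage sketch (the URL is a placeholder; requestBytes and responseBuf are
// assumed byte slices): writes are buffered locally and Flush performs the
// HTTP POST, after which the response body is readable through the same
// transport.
//
// trans, err := NewTHttpClient("http://localhost:9090/thrift")
// if err != nil { /* handle */ }
// trans.Write(requestBytes)
// err = trans.Flush(context.Background())
// n, err := trans.Read(responseBuf)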


@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"compress/gzip"
"io"
"net/http"
"strings"
)
// NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function
func NewThriftHandlerFunc(processor TProcessor,
inPfactory, outPfactory TProtocolFactory) func(w http.ResponseWriter, r *http.Request) {
return gz(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/x-thrift")
transport := NewStreamTransport(r.Body, w)
processor.Process(r.Context(), inPfactory.GetProtocol(transport), outPfactory.GetProtocol(transport))
})
}
// gz transparently compresses the HTTP response if the client supports it.
func gz(handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
handler(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(w)
defer gz.Close()
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
handler(gzw, r)
}
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
func (w gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}
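// Usage sketch (processor comes from generated code and is assumed here):
// mount the handler on any mux; gzip is negotiated per request via the
// Accept-Encoding header.
//
// handler := NewThriftHandlerFunc(processor,
// NewTBinaryProtocolFactoryDefault(), NewTBinaryProtocolFactoryDefault())
// http.HandleFunc("/thrift", handler)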


@ -0,0 +1,214 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bufio"
"context"
"io"
)
// StreamTransport is a Transport made of an io.Reader and/or an io.Writer
type StreamTransport struct {
io.Reader
io.Writer
isReadWriter bool
closed bool
}
type StreamTransportFactory struct {
Reader io.Reader
Writer io.Writer
isReadWriter bool
}
func (p *StreamTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*StreamTransport)
if ok {
if t.isReadWriter {
return NewStreamTransportRW(t.Reader.(io.ReadWriter)), nil
}
if t.Reader != nil && t.Writer != nil {
return NewStreamTransport(t.Reader, t.Writer), nil
}
if t.Reader != nil && t.Writer == nil {
return NewStreamTransportR(t.Reader), nil
}
if t.Reader == nil && t.Writer != nil {
return NewStreamTransportW(t.Writer), nil
}
return &StreamTransport{}, nil
}
}
if p.isReadWriter {
return NewStreamTransportRW(p.Reader.(io.ReadWriter)), nil
}
if p.Reader != nil && p.Writer != nil {
return NewStreamTransport(p.Reader, p.Writer), nil
}
if p.Reader != nil && p.Writer == nil {
return NewStreamTransportR(p.Reader), nil
}
if p.Reader == nil && p.Writer != nil {
return NewStreamTransportW(p.Writer), nil
}
return &StreamTransport{}, nil
}
func NewStreamTransportFactory(reader io.Reader, writer io.Writer, isReadWriter bool) *StreamTransportFactory {
return &StreamTransportFactory{Reader: reader, Writer: writer, isReadWriter: isReadWriter}
}
func NewStreamTransport(r io.Reader, w io.Writer) *StreamTransport {
return &StreamTransport{Reader: bufio.NewReader(r), Writer: bufio.NewWriter(w)}
}
func NewStreamTransportR(r io.Reader) *StreamTransport {
return &StreamTransport{Reader: bufio.NewReader(r)}
}
func NewStreamTransportW(w io.Writer) *StreamTransport {
return &StreamTransport{Writer: bufio.NewWriter(w)}
}
func NewStreamTransportRW(rw io.ReadWriter) *StreamTransport {
bufrw := bufio.NewReadWriter(bufio.NewReader(rw), bufio.NewWriter(rw))
return &StreamTransport{Reader: bufrw, Writer: bufrw, isReadWriter: true}
}
func (p *StreamTransport) IsOpen() bool {
return !p.closed
}
// implicitly opened on creation, can't be reopened once closed
func (p *StreamTransport) Open() error {
if !p.closed {
return NewTTransportException(ALREADY_OPEN, "StreamTransport already open.")
} else {
return NewTTransportException(NOT_OPEN, "cannot reopen StreamTransport.")
}
}
// Closes both the input and output streams.
func (p *StreamTransport) Close() error {
if p.closed {
return NewTTransportException(NOT_OPEN, "StreamTransport already closed.")
}
p.closed = true
closedReader := false
if p.Reader != nil {
c, ok := p.Reader.(io.Closer)
if ok {
e := c.Close()
closedReader = true
if e != nil {
return e
}
}
p.Reader = nil
}
if p.Writer != nil && (!closedReader || !p.isReadWriter) {
c, ok := p.Writer.(io.Closer)
if ok {
e := c.Close()
if e != nil {
return e
}
}
p.Writer = nil
}
return nil
}
// Flushes the underlying output stream if not null.
func (p *StreamTransport) Flush(ctx context.Context) error {
if p.Writer == nil {
return NewTTransportException(NOT_OPEN, "Cannot flush null outputStream")
}
f, ok := p.Writer.(Flusher)
if ok {
err := f.Flush()
if err != nil {
return NewTTransportExceptionFromError(err)
}
}
return nil
}
func (p *StreamTransport) Read(c []byte) (n int, err error) {
n, err = p.Reader.Read(c)
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) ReadByte() (c byte, err error) {
f, ok := p.Reader.(io.ByteReader)
if ok {
c, err = f.ReadByte()
} else {
c, err = readByte(p.Reader)
}
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) Write(c []byte) (n int, err error) {
n, err = p.Writer.Write(c)
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) WriteByte(c byte) (err error) {
f, ok := p.Writer.(io.ByteWriter)
if ok {
err = f.WriteByte(c)
} else {
err = writeByte(p.Writer, c)
}
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) WriteString(s string) (n int, err error) {
f, ok := p.Writer.(stringWriter)
if ok {
n, err = f.WriteString(s)
} else {
n, err = p.Writer.Write([]byte(s))
}
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the truth is, we just don't know unless framed is used
}


@ -0,0 +1,584 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"encoding/base64"
"fmt"
)
const (
THRIFT_JSON_PROTOCOL_VERSION = 1
)
// for references to _ParseContext see tsimplejson_protocol.go
// JSON protocol implementation for thrift.
//
// This is the full-featured, typed TJSONProtocol: values are wrapped with
// type information so they can be decoded losslessly. It should not be
// confused with TSimpleJSONProtocol, whose simpler output targets scripting
// languages.
//
type TJSONProtocol struct {
*TSimpleJSONProtocol
}
// Constructor
func NewTJSONProtocol(t TTransport) *TJSONProtocol {
v := &TJSONProtocol{TSimpleJSONProtocol: NewTSimpleJSONProtocol(t)}
v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
return v
}
// Factory
type TJSONProtocolFactory struct{}
func (p *TJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return NewTJSONProtocol(trans)
}
func NewTJSONProtocolFactory() *TJSONProtocolFactory {
return &TJSONProtocolFactory{}
}
func (p *TJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
p.resetContextStack() // THRIFT-3735
if e := p.OutputListBegin(); e != nil {
return e
}
if e := p.WriteI32(THRIFT_JSON_PROTOCOL_VERSION); e != nil {
return e
}
if e := p.WriteString(name); e != nil {
return e
}
if e := p.WriteByte(int8(typeId)); e != nil {
return e
}
if e := p.WriteI32(seqId); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteMessageEnd() error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteStructBegin(name string) error {
if e := p.OutputObjectBegin(); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteStructEnd() error {
return p.OutputObjectEnd()
}
func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
if e := p.WriteI16(id); e != nil {
return e
}
if e := p.OutputObjectBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(typeId)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteFieldEnd() error {
return p.OutputObjectEnd()
}
func (p *TJSONProtocol) WriteFieldStop() error { return nil }
func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(keyType)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
s, e1 = p.TypeIdToString(valueType)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
if e := p.WriteI64(int64(size)); e != nil {
return e
}
return p.OutputObjectBegin()
}
func (p *TJSONProtocol) WriteMapEnd() error {
if e := p.OutputObjectEnd(); e != nil {
return e
}
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteListBegin(elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
func (p *TJSONProtocol) WriteListEnd() error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteSetBegin(elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
func (p *TJSONProtocol) WriteSetEnd() error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteBool(b bool) error {
if b {
return p.WriteI32(1)
}
return p.WriteI32(0)
}
func (p *TJSONProtocol) WriteByte(b int8) error {
return p.WriteI32(int32(b))
}
func (p *TJSONProtocol) WriteI16(v int16) error {
return p.WriteI32(int32(v))
}
func (p *TJSONProtocol) WriteI32(v int32) error {
return p.OutputI64(int64(v))
}
func (p *TJSONProtocol) WriteI64(v int64) error {
return p.OutputI64(int64(v))
}
func (p *TJSONProtocol) WriteDouble(v float64) error {
return p.OutputF64(v)
}
func (p *TJSONProtocol) WriteString(v string) error {
return p.OutputString(v)
}
func (p *TJSONProtocol) WriteBinary(v []byte) error {
// JSON library only takes in a string,
// not an arbitrary byte array, to ensure bytes are transmitted
// efficiently we must convert this into a valid JSON string
// therefore we use base64 encoding to avoid excessive escaping/quoting
if e := p.OutputPreValue(); e != nil {
return e
}
if _, e := p.write(JSON_QUOTE_BYTES); e != nil {
return NewTProtocolException(e)
}
writer := base64.NewEncoder(base64.StdEncoding, p.writer)
if _, e := writer.Write(v); e != nil {
p.writer.Reset(p.trans) // THRIFT-3735
return NewTProtocolException(e)
}
if e := writer.Close(); e != nil {
return NewTProtocolException(e)
}
if _, e := p.write(JSON_QUOTE_BYTES); e != nil {
return NewTProtocolException(e)
}
return p.OutputPostValue()
}
// Reading methods.
func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
p.resetContextStack() // THRIFT-3735
if isNull, err := p.ParseListBegin(); isNull || err != nil {
return name, typeId, seqId, err
}
version, err := p.ReadI32()
if err != nil {
return name, typeId, seqId, err
}
if version != THRIFT_JSON_PROTOCOL_VERSION {
e := fmt.Errorf("Unknown Protocol version %d, expected version %d", version, THRIFT_JSON_PROTOCOL_VERSION)
return name, typeId, seqId, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
if name, err = p.ReadString(); err != nil {
return name, typeId, seqId, err
}
bTypeId, err := p.ReadByte()
typeId = TMessageType(bTypeId)
if err != nil {
return name, typeId, seqId, err
}
if seqId, err = p.ReadI32(); err != nil {
return name, typeId, seqId, err
}
return name, typeId, seqId, nil
}
func (p *TJSONProtocol) ReadMessageEnd() error {
err := p.ParseListEnd()
return err
}
func (p *TJSONProtocol) ReadStructBegin() (name string, err error) {
_, err = p.ParseObjectStart()
return "", err
}
func (p *TJSONProtocol) ReadStructEnd() error {
return p.ParseObjectEnd()
}
func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
b, _ := p.reader.Peek(1)
if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] {
return "", STOP, -1, nil
}
fieldId, err := p.ReadI16()
if err != nil {
return "", STOP, fieldId, err
}
if _, err = p.ParseObjectStart(); err != nil {
return "", STOP, fieldId, err
}
sType, err := p.ReadString()
if err != nil {
return "", STOP, fieldId, err
}
fType, err := p.StringToTypeId(sType)
return "", fType, fieldId, err
}
func (p *TJSONProtocol) ReadFieldEnd() error {
return p.ParseObjectEnd()
}
func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, VOID, 0, e
}
// read keyType
sKeyType, e := p.ReadString()
if e != nil {
return keyType, valueType, size, e
}
keyType, e = p.StringToTypeId(sKeyType)
if e != nil {
return keyType, valueType, size, e
}
// read valueType
sValueType, e := p.ReadString()
if e != nil {
return keyType, valueType, size, e
}
valueType, e = p.StringToTypeId(sValueType)
if e != nil {
return keyType, valueType, size, e
}
// read size
iSize, e := p.ReadI64()
if e != nil {
return keyType, valueType, size, e
}
size = int(iSize)
_, e = p.ParseObjectStart()
return keyType, valueType, size, e
}
func (p *TJSONProtocol) ReadMapEnd() error {
e := p.ParseObjectEnd()
if e != nil {
return e
}
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadListBegin() (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
func (p *TJSONProtocol) ReadListEnd() error {
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
func (p *TJSONProtocol) ReadSetEnd() error {
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadBool() (bool, error) {
value, err := p.ReadI32()
return (value != 0), err
}
func (p *TJSONProtocol) ReadByte() (int8, error) {
v, err := p.ReadI64()
return int8(v), err
}
func (p *TJSONProtocol) ReadI16() (int16, error) {
v, err := p.ReadI64()
return int16(v), err
}
func (p *TJSONProtocol) ReadI32() (int32, error) {
v, err := p.ReadI64()
return int32(v), err
}
func (p *TJSONProtocol) ReadI64() (int64, error) {
v, _, err := p.ParseI64()
return v, err
}
func (p *TJSONProtocol) ReadDouble() (float64, error) {
v, _, err := p.ParseF64()
return v, err
}
func (p *TJSONProtocol) ReadString() (string, error) {
var v string
if err := p.ParsePreValue(); err != nil {
return v, err
}
f, _ := p.reader.Peek(1)
if len(f) > 0 && f[0] == JSON_QUOTE {
p.reader.ReadByte()
value, err := p.ParseStringBody()
v = value
if err != nil {
return v, err
}
} else if len(f) > 0 && f[0] == JSON_NULL[0] {
b := make([]byte, len(JSON_NULL))
_, err := p.reader.Read(b)
if err != nil {
return v, NewTProtocolException(err)
}
if string(b) != string(JSON_NULL) {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
} else {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
return v, p.ParsePostValue()
}
func (p *TJSONProtocol) ReadBinary() ([]byte, error) {
var v []byte
if err := p.ParsePreValue(); err != nil {
return nil, err
}
f, _ := p.reader.Peek(1)
if len(f) > 0 && f[0] == JSON_QUOTE {
p.reader.ReadByte()
value, err := p.ParseBase64EncodedBody()
v = value
if err != nil {
return v, err
}
} else if len(f) > 0 && f[0] == JSON_NULL[0] {
b := make([]byte, len(JSON_NULL))
_, err := p.reader.Read(b)
if err != nil {
return v, NewTProtocolException(err)
}
if string(b) != string(JSON_NULL) {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
} else {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
return v, p.ParsePostValue()
}
func (p *TJSONProtocol) Flush(ctx context.Context) (err error) {
err = p.writer.Flush()
if err == nil {
err = p.trans.Flush(ctx)
}
return NewTProtocolException(err)
}
func (p *TJSONProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
}
func (p *TJSONProtocol) Transport() TTransport {
return p.trans
}
func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(elemType)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
if e := p.WriteI64(int64(size)); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
sElemType, err := p.ReadString()
if err != nil {
return VOID, size, err
}
elemType, err = p.StringToTypeId(sElemType)
if err != nil {
return elemType, size, err
}
nSize, err2 := p.ReadI64()
size = int(nSize)
return elemType, size, err2
}
func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
sElemType, err := p.ReadString()
if err != nil {
return VOID, size, err
}
elemType, err = p.StringToTypeId(sElemType)
if err != nil {
return elemType, size, err
}
nSize, err2 := p.ReadI64()
size = int(nSize)
return elemType, size, err2
}
func (p *TJSONProtocol) writeElemListBegin(elemType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(elemType)
if e1 != nil {
return e1
}
if e := p.OutputString(s); e != nil {
return e
}
if e := p.OutputI64(int64(size)); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) TypeIdToString(fieldType TType) (string, error) {
switch byte(fieldType) {
case BOOL:
return "tf", nil
case BYTE:
return "i8", nil
case I16:
return "i16", nil
case I32:
return "i32", nil
case I64:
return "i64", nil
case DOUBLE:
return "dbl", nil
case STRING:
return "str", nil
case STRUCT:
return "rec", nil
case MAP:
return "map", nil
case SET:
return "set", nil
case LIST:
return "lst", nil
}
e := fmt.Errorf("Unknown fieldType: %d", int(fieldType))
return "", NewTProtocolExceptionWithType(INVALID_DATA, e)
}
func (p *TJSONProtocol) StringToTypeId(fieldType string) (TType, error) {
switch fieldType {
case "tf":
return TType(BOOL), nil
case "i8":
return TType(BYTE), nil
case "i16":
return TType(I16), nil
case "i32":
return TType(I32), nil
case "i64":
return TType(I64), nil
case "dbl":
return TType(DOUBLE), nil
case "str":
return TType(STRING), nil
case "rec":
return TType(STRUCT), nil
case "map":
return TType(MAP), nil
case "set":
return TType(SET), nil
case "lst":
return TType(LIST), nil
}
e := fmt.Errorf("Unknown type identifier: %s", fieldType)
return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e)
}
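// Wire-format sketch (illustrative): with the mappings above, a one-entry
// map<i16,string> is framed as a JSON array carrying the type strings and
// size ahead of the entries object, e.g.
//
// ["i16","str",1,{"5":"hello"}]
//
// ReadMapBegin recovers the types via StringToTypeId and the size via
// ReadI64 before descending into the object.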


@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"context"
)
// Memory buffer-based implementation of the TTransport interface.
type TMemoryBuffer struct {
*bytes.Buffer
size int
}
type TMemoryBufferTransportFactory struct {
size int
}
func (p *TMemoryBufferTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*TMemoryBuffer)
if ok && t.size > 0 {
return NewTMemoryBufferLen(t.size), nil
}
}
return NewTMemoryBufferLen(p.size), nil
}
func NewTMemoryBufferTransportFactory(size int) *TMemoryBufferTransportFactory {
return &TMemoryBufferTransportFactory{size: size}
}
func NewTMemoryBuffer() *TMemoryBuffer {
return &TMemoryBuffer{Buffer: &bytes.Buffer{}, size: 0}
}
func NewTMemoryBufferLen(size int) *TMemoryBuffer {
buf := make([]byte, 0, size)
return &TMemoryBuffer{Buffer: bytes.NewBuffer(buf), size: size}
}
func (p *TMemoryBuffer) IsOpen() bool {
return true
}
func (p *TMemoryBuffer) Open() error {
return nil
}
func (p *TMemoryBuffer) Close() error {
p.Buffer.Reset()
return nil
}
// Flushing a memory buffer is a no-op
func (p *TMemoryBuffer) Flush(ctx context.Context) error {
return nil
}
func (p *TMemoryBuffer) RemainingBytes() (num_bytes uint64) {
return uint64(p.Buffer.Len())
}
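// Usage sketch (illustrative): TMemoryBuffer is the in-process transport for
// (de)serialization; bytes written through a protocol are immediately
// readable from the embedded bytes.Buffer.
//
// buf := NewTMemoryBuffer()
// proto := NewTJSONProtocol(buf)
// proto.WriteString("hi")
// proto.Flush(context.Background())
// s := buf.String() // "\"hi\""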


@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Message type constants in the Thrift protocol.
type TMessageType int32
const (
INVALID_TMESSAGE_TYPE TMessageType = 0
CALL TMessageType = 1
REPLY TMessageType = 2
EXCEPTION TMessageType = 3
ONEWAY TMessageType = 4
)


@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"fmt"
"strings"
)
/*
TMultiplexedProtocol is a protocol-independent concrete decorator
that allows a Thrift client to communicate with a multiplexing Thrift server,
by prepending the service name to the function name during function calls.
NOTE: THIS IS NOT USED BY SERVERS. On the server, use TMultiplexedProcessor to handle requests
from a multiplexing client.
This example uses a single socket transport to invoke two services:
socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
transport := thrift.NewTFramedTransport(socket)
protocol := thrift.NewTBinaryProtocolTransport(transport)
mp := thrift.NewTMultiplexedProtocol(protocol, "Calculator")
service := Calculator.NewCalculatorClient(mp)
mp2 := thrift.NewTMultiplexedProtocol(protocol, "WeatherReport")
service2 := WeatherReport.NewWeatherReportClient(mp2)
err := transport.Open()
if err != nil {
t.Fatal("Unable to open client socket", err)
}
fmt.Println(service.Add(2,2))
fmt.Println(service2.GetTemperature())
*/
type TMultiplexedProtocol struct {
TProtocol
serviceName string
}
const MULTIPLEXED_SEPARATOR = ":"
func NewTMultiplexedProtocol(protocol TProtocol, serviceName string) *TMultiplexedProtocol {
return &TMultiplexedProtocol{
TProtocol: protocol,
serviceName: serviceName,
}
}
func (t *TMultiplexedProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
if typeId == CALL || typeId == ONEWAY {
return t.TProtocol.WriteMessageBegin(t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid)
} else {
return t.TProtocol.WriteMessageBegin(name, typeId, seqid)
}
}
/*
TMultiplexedProcessor is a TProcessor allowing
a single TServer to provide multiple services.
To do so, you instantiate the processor and then register additional
processors with it, as shown in the following example:
var processor = thrift.NewTMultiplexedProcessor()
processor.RegisterProcessor(
"Calculator",
Calculator.NewCalculatorProcessor(&CalculatorHandler{}),
)
processor.RegisterProcessor(
"WeatherReport",
WeatherReport.NewWeatherReportProcessor(&WeatherReportHandler{}),
)
serverTransport, err := thrift.NewTServerSocketTimeout(addr, TIMEOUT)
if err != nil {
t.Fatal("Unable to create server socket", err)
}
server := thrift.NewTSimpleServer2(processor, serverTransport)
server.Serve();
*/
type TMultiplexedProcessor struct {
serviceProcessorMap map[string]TProcessor
DefaultProcessor TProcessor
}
func NewTMultiplexedProcessor() *TMultiplexedProcessor {
return &TMultiplexedProcessor{
serviceProcessorMap: make(map[string]TProcessor),
}
}
func (t *TMultiplexedProcessor) RegisterDefault(processor TProcessor) {
t.DefaultProcessor = processor
}
func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProcessor) {
if t.serviceProcessorMap == nil {
t.serviceProcessorMap = make(map[string]TProcessor)
}
t.serviceProcessorMap[name] = processor
}
func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
name, typeId, seqid, err := in.ReadMessageBegin()
if err != nil {
return false, err
}
if typeId != CALL && typeId != ONEWAY {
return false, fmt.Errorf("Unexpected message type %v", typeId)
}
//extract the service name
v := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2)
if len(v) != 2 {
if t.DefaultProcessor != nil {
smb := NewStoredMessageProtocol(in, name, typeId, seqid)
return t.DefaultProcessor.Process(ctx, smb, out)
}
return false, fmt.Errorf("Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?", name)
}
actualProcessor, ok := t.serviceProcessorMap[v[0]]
if !ok {
return false, fmt.Errorf("Service name not found: %s. Did you forget to call registerProcessor()?", v[0])
}
smb := NewStoredMessageProtocol(in, v[1], typeId, seqid)
return actualProcessor.Process(ctx, smb, out)
}
// Protocol that uses a stored message for ReadMessageBegin
type storedMessageProtocol struct {
TProtocol
name string
typeId TMessageType
seqid int32
}
func NewStoredMessageProtocol(protocol TProtocol, name string, typeId TMessageType, seqid int32) *storedMessageProtocol {
return &storedMessageProtocol{protocol, name, typeId, seqid}
}
func (s *storedMessageProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
return s.name, s.typeId, s.seqid, nil
}


@ -0,0 +1,164 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"math"
"strconv"
)
type Numeric interface {
Int64() int64
Int32() int32
Int16() int16
Byte() byte
Int() int
Float64() float64
Float32() float32
String() string
isNull() bool
}
type numeric struct {
iValue int64
dValue float64
sValue string
isNil bool
}
var (
INFINITY Numeric
NEGATIVE_INFINITY Numeric
NAN Numeric
ZERO Numeric
NUMERIC_NULL Numeric
)
func NewNumericFromDouble(dValue float64) Numeric {
if math.IsInf(dValue, 1) {
return INFINITY
}
if math.IsInf(dValue, -1) {
return NEGATIVE_INFINITY
}
if math.IsNaN(dValue) {
return NAN
}
iValue := int64(dValue)
sValue := strconv.FormatFloat(dValue, 'g', 10, 64)
isNil := false
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromI64(iValue int64) Numeric {
dValue := float64(iValue)
sValue := strconv.FormatInt(iValue, 10)
isNil := false
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromI32(iValue int32) Numeric {
dValue := float64(iValue)
sValue := strconv.FormatInt(int64(iValue), 10)
isNil := false
return &numeric{iValue: int64(iValue), dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromString(sValue string) Numeric {
if sValue == INFINITY.String() {
return INFINITY
}
if sValue == NEGATIVE_INFINITY.String() {
return NEGATIVE_INFINITY
}
if sValue == NAN.String() {
return NAN
}
iValue, _ := strconv.ParseInt(sValue, 10, 64)
dValue, _ := strconv.ParseFloat(sValue, 64)
isNil := len(sValue) == 0
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromJSONString(sValue string, isNull bool) Numeric {
if isNull {
return NewNullNumeric()
}
if sValue == JSON_INFINITY {
return INFINITY
}
if sValue == JSON_NEGATIVE_INFINITY {
return NEGATIVE_INFINITY
}
if sValue == JSON_NAN {
return NAN
}
iValue, _ := strconv.ParseInt(sValue, 10, 64)
dValue, _ := strconv.ParseFloat(sValue, 64)
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNull}
}
func NewNullNumeric() Numeric {
return &numeric{iValue: 0, dValue: 0.0, sValue: "", isNil: true}
}
func (p *numeric) Int64() int64 {
return p.iValue
}
func (p *numeric) Int32() int32 {
return int32(p.iValue)
}
func (p *numeric) Int16() int16 {
return int16(p.iValue)
}
func (p *numeric) Byte() byte {
return byte(p.iValue)
}
func (p *numeric) Int() int {
return int(p.iValue)
}
func (p *numeric) Float64() float64 {
return p.dValue
}
func (p *numeric) Float32() float32 {
return float32(p.dValue)
}
func (p *numeric) String() string {
return p.sValue
}
func (p *numeric) isNull() bool {
return p.isNil
}
func init() {
INFINITY = &numeric{iValue: 0, dValue: math.Inf(1), sValue: "Infinity", isNil: false}
NEGATIVE_INFINITY = &numeric{iValue: 0, dValue: math.Inf(-1), sValue: "-Infinity", isNil: false}
NAN = &numeric{iValue: 0, dValue: math.NaN(), sValue: "NaN", isNil: false}
ZERO = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: false}
NUMERIC_NULL = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: true}
}
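
A small sketch of the behavior above: every Numeric exposes integer, float, and string views of one value, and special doubles collapse onto the package-level sentinels.

package main

import (
    "fmt"
    "math"

    "git.apache.org/thrift.git/lib/go/thrift"
)

func main() {
    n := thrift.NewNumericFromDouble(3.5)
    fmt.Println(n.Int64(), n.Float64(), n.String()) // 3 3.5 3.5

    // Special doubles map to the shared sentinels, so the interface
    // values compare equal.
    fmt.Println(thrift.NewNumericFromDouble(math.Inf(1)) == thrift.INFINITY) // true
}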


@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
///////////////////////////////////////////////////////////////////////////////
// This file is home to helpers that convert from various base types to
// respective pointer types. This is necessary because Go does not permit
// references to constants, nor can a pointer to a base type be allocated
// and initialized in a single expression.
//
// E.g., this is not allowed:
//
// var ip *int = &5
//
// But this *is* allowed:
//
// func IntPtr(i int) *int { return &i }
// var ip *int = IntPtr(5)
//
// Since pointers to base types are commonplace as [optional] fields in
// exported thrift structs, we factor such helpers here.
///////////////////////////////////////////////////////////////////////////////
func Float32Ptr(v float32) *float32 { return &v }
func Float64Ptr(v float64) *float64 { return &v }
func IntPtr(v int) *int { return &v }
func Int32Ptr(v int32) *int32 { return &v }
func Int64Ptr(v int64) *int64 { return &v }
func StringPtr(v string) *string { return &v }
func Uint32Ptr(v uint32) *uint32 { return &v }
func Uint64Ptr(v uint64) *uint64 { return &v }
func BoolPtr(v bool) *bool { return &v }
func ByteSlicePtr(v []byte) *[]byte { return &v }
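
Typical use is populating optional fields of generated structs; a hedged sketch (UserProfile and its Name field are hypothetical stand-ins for thrift-generated code, with the package imported as thrift):

// Optional fields are pointers in generated code; &"alice" is not legal Go,
// so the helper does the allocation.
type UserProfile struct {
    Name *string // optional field, as codegen would emit it
}

var profile = UserProfile{Name: thrift.StringPtr("alice")}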


@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "context"
// A processor is a generic object which operates upon an input stream and
// writes to some output stream.
type TProcessor interface {
Process(ctx context.Context, in, out TProtocol) (bool, TException)
}
type TProcessorFunction interface {
Process(ctx context.Context, seqId int32, in, out TProtocol) (bool, TException)
}
// The default processor factory just returns a singleton
// instance.
type TProcessorFactory interface {
GetProcessor(trans TTransport) TProcessor
}
type tProcessorFactory struct {
processor TProcessor
}
func NewTProcessorFactory(p TProcessor) TProcessorFactory {
return &tProcessorFactory{processor: p}
}
func (p *tProcessorFactory) GetProcessor(trans TTransport) TProcessor {
return p.processor
}
// The default processor function factory just returns a singleton
// instance.
type TProcessorFunctionFactory interface {
GetProcessorFunction(trans TTransport) TProcessorFunction
}
type tProcessorFunctionFactory struct {
processor TProcessorFunction
}
func NewTProcessorFunctionFactory(p TProcessorFunction) TProcessorFunctionFactory {
return &tProcessorFunctionFactory{processor: p}
}
func (p *tProcessorFunctionFactory) GetProcessorFunction(trans TTransport) TProcessorFunction {
return p.processor
}


@ -0,0 +1,179 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"errors"
"fmt"
)
const (
VERSION_MASK = 0xffff0000
VERSION_1 = 0x80010000
)
type TProtocol interface {
WriteMessageBegin(name string, typeId TMessageType, seqid int32) error
WriteMessageEnd() error
WriteStructBegin(name string) error
WriteStructEnd() error
WriteFieldBegin(name string, typeId TType, id int16) error
WriteFieldEnd() error
WriteFieldStop() error
WriteMapBegin(keyType TType, valueType TType, size int) error
WriteMapEnd() error
WriteListBegin(elemType TType, size int) error
WriteListEnd() error
WriteSetBegin(elemType TType, size int) error
WriteSetEnd() error
WriteBool(value bool) error
WriteByte(value int8) error
WriteI16(value int16) error
WriteI32(value int32) error
WriteI64(value int64) error
WriteDouble(value float64) error
WriteString(value string) error
WriteBinary(value []byte) error
ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error)
ReadMessageEnd() error
ReadStructBegin() (name string, err error)
ReadStructEnd() error
ReadFieldBegin() (name string, typeId TType, id int16, err error)
ReadFieldEnd() error
ReadMapBegin() (keyType TType, valueType TType, size int, err error)
ReadMapEnd() error
ReadListBegin() (elemType TType, size int, err error)
ReadListEnd() error
ReadSetBegin() (elemType TType, size int, err error)
ReadSetEnd() error
ReadBool() (value bool, err error)
ReadByte() (value int8, err error)
ReadI16() (value int16, err error)
ReadI32() (value int32, err error)
ReadI64() (value int64, err error)
ReadDouble() (value float64, err error)
ReadString() (value string, err error)
ReadBinary() (value []byte, err error)
Skip(fieldType TType) (err error)
Flush(ctx context.Context) (err error)
Transport() TTransport
}
// The maximum recursive depth the skip() function will traverse
const DEFAULT_RECURSION_DEPTH = 64
// Skips over the next data element from the provided input TProtocol object.
func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) {
return Skip(prot, typeId, DEFAULT_RECURSION_DEPTH)
}
// Skips over the next data element from the provided input TProtocol object.
func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) {
if maxDepth <= 0 {
return NewTProtocolExceptionWithType(DEPTH_LIMIT, errors.New("Depth limit exceeded"))
}
switch fieldType {
case STOP:
return
case BOOL:
_, err = self.ReadBool()
return
case BYTE:
_, err = self.ReadByte()
return
case I16:
_, err = self.ReadI16()
return
case I32:
_, err = self.ReadI32()
return
case I64:
_, err = self.ReadI64()
return
case DOUBLE:
_, err = self.ReadDouble()
return
case STRING:
_, err = self.ReadString()
return
case STRUCT:
if _, err = self.ReadStructBegin(); err != nil {
return err
}
for {
_, typeId, _, err := self.ReadFieldBegin()
if err != nil {
return err
}
if typeId == STOP {
break
}
if err := Skip(self, typeId, maxDepth-1); err != nil {
return err
}
if err := self.ReadFieldEnd(); err != nil {
return err
}
}
return self.ReadStructEnd()
case MAP:
keyType, valueType, size, err := self.ReadMapBegin()
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, keyType, maxDepth-1)
if err != nil {
return err
}
self.Skip(valueType)
}
return self.ReadMapEnd()
case SET:
elemType, size, err := self.ReadSetBegin()
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, elemType, maxDepth-1)
if err != nil {
return err
}
}
return self.ReadSetEnd()
case LIST:
elemType, size, err := self.ReadListBegin()
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, elemType, maxDepth-1)
if err != nil {
return err
}
}
return self.ReadListEnd()
default:
return NewTProtocolExceptionWithType(INVALID_DATA, fmt.Errorf("Unknown data type %d", fieldType))
}
return nil
}
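
The typical caller is a struct reader that decodes the fields it knows and skips everything else, which is how thrift stays forward-compatible. A hedged sketch (field id 1 as the only known field is an assumption, and the thrift import is assumed):

// Decode field 1 if present; skip any field this reader does not know.
func readTolerantStruct(prot thrift.TProtocol) (int32, error) {
    var known int32
    if _, err := prot.ReadStructBegin(); err != nil {
        return 0, err
    }
    for {
        _, typeId, id, err := prot.ReadFieldBegin()
        if err != nil {
            return 0, err
        }
        if typeId == thrift.STOP {
            break
        }
        if id == 1 && typeId == thrift.I32 {
            if known, err = prot.ReadI32(); err != nil {
                return 0, err
            }
        } else if err := thrift.SkipDefaultDepth(prot, typeId); err != nil {
            return 0, err
        }
        if err := prot.ReadFieldEnd(); err != nil {
            return 0, err
        }
    }
    return known, prot.ReadStructEnd()
}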


@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"encoding/base64"
)
// Thrift Protocol exception
type TProtocolException interface {
TException
TypeId() int
}
const (
UNKNOWN_PROTOCOL_EXCEPTION = 0
INVALID_DATA = 1
NEGATIVE_SIZE = 2
SIZE_LIMIT = 3
BAD_VERSION = 4
NOT_IMPLEMENTED = 5
DEPTH_LIMIT = 6
)
type tProtocolException struct {
typeId int
message string
}
func (p *tProtocolException) TypeId() int {
return p.typeId
}
func (p *tProtocolException) String() string {
return p.message
}
func (p *tProtocolException) Error() string {
return p.message
}
func NewTProtocolException(err error) TProtocolException {
if err == nil {
return nil
}
if e, ok := err.(TProtocolException); ok {
return e
}
if _, ok := err.(base64.CorruptInputError); ok {
return &tProtocolException{INVALID_DATA, err.Error()}
}
return &tProtocolException{UNKNOWN_PROTOCOL_EXCEPTION, err.Error()}
}
func NewTProtocolExceptionWithType(errType int, err error) TProtocolException {
if err == nil {
return nil
}
return &tProtocolException{errType, err.Error()}
}


@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Factory interface for constructing protocol instances.
type TProtocolFactory interface {
GetProtocol(trans TTransport) TProtocol
}


@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "io"
type RichTransport struct {
TTransport
}
// Wraps Transport to provide TRichTransport interface
func NewTRichTransport(trans TTransport) *RichTransport {
return &RichTransport{trans}
}
func (r *RichTransport) ReadByte() (c byte, err error) {
return readByte(r.TTransport)
}
func (r *RichTransport) WriteByte(c byte) error {
return writeByte(r.TTransport, c)
}
func (r *RichTransport) WriteString(s string) (n int, err error) {
return r.Write([]byte(s))
}
func (r *RichTransport) RemainingBytes() (num_bytes uint64) {
return r.TTransport.RemainingBytes()
}
func readByte(r io.Reader) (c byte, err error) {
v := [1]byte{0}
n, err := r.Read(v[0:1])
if n > 0 && (err == nil || err == io.EOF) {
return v[0], nil
}
if n > 0 && err != nil {
return v[0], err
}
if err != nil {
return 0, err
}
return v[0], nil
}
func writeByte(w io.Writer, c byte) error {
v := [1]byte{c}
_, err := w.Write(v[0:1])
return err
}


@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
)
type TSerializer struct {
Transport *TMemoryBuffer
Protocol TProtocol
}
type TStruct interface {
Write(p TProtocol) error
Read(p TProtocol) error
}
func NewTSerializer() *TSerializer {
transport := NewTMemoryBufferLen(1024)
protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport)
return &TSerializer{
transport,
protocol}
}
func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
return
}
if err = t.Protocol.Flush(ctx); err != nil {
return
}
if err = t.Transport.Flush(ctx); err != nil {
return
}
return t.Transport.String(), nil
}
func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
return
}
if err = t.Protocol.Flush(ctx); err != nil {
return
}
if err = t.Transport.Flush(ctx); err != nil {
return
}
b = append(b, t.Transport.Bytes()...)
return
}
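
A thin wrapper is usually all that is needed on top of this, since every thrift-generated struct satisfies TStruct. A sketch, assuming the package is imported as thrift and a context.Context is in scope:

// Encode any generated struct to bytes using the default binary protocol.
func encode(ctx context.Context, msg thrift.TStruct) ([]byte, error) {
    // Each call gets a fresh serializer; TSerializer is not safe for
    // concurrent use because it resets its memory buffer on every Write.
    return thrift.NewTSerializer().Write(ctx, msg)
}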


@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
type TServer interface {
ProcessorFactory() TProcessorFactory
ServerTransport() TServerTransport
InputTransportFactory() TTransportFactory
OutputTransportFactory() TTransportFactory
InputProtocolFactory() TProtocolFactory
OutputProtocolFactory() TProtocolFactory
// Starts the server
Serve() error
// Stops the server. This is optional on a per-implementation basis. Not
// all servers are required to be cleanly stoppable.
Stop() error
}


@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"net"
"sync"
"time"
)
type TServerSocket struct {
listener net.Listener
addr net.Addr
clientTimeout time.Duration
// Protects the interrupted value to make it thread safe.
mu sync.RWMutex
interrupted bool
}
func NewTServerSocket(listenAddr string) (*TServerSocket, error) {
return NewTServerSocketTimeout(listenAddr, 0)
}
func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*TServerSocket, error) {
addr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, err
}
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}, nil
}
// Creates a TServerSocket from a net.Addr
func NewTServerSocketFromAddrTimeout(addr net.Addr, clientTimeout time.Duration) *TServerSocket {
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}
}
func (p *TServerSocket) Listen() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {
return nil
}
l, err := net.Listen(p.addr.Network(), p.addr.String())
if err != nil {
return err
}
p.listener = l
return nil
}
func (p *TServerSocket) Accept() (TTransport, error) {
p.mu.RLock()
interrupted := p.interrupted
p.mu.RUnlock()
if interrupted {
return nil, errTransportInterrupted
}
listener := p.listener
if listener == nil {
return nil, NewTTransportException(NOT_OPEN, "No underlying server socket")
}
conn, err := listener.Accept()
if err != nil {
return nil, NewTTransportExceptionFromError(err)
}
return NewTSocketFromConnTimeout(conn, p.clientTimeout), nil
}
// Checks whether the socket is listening.
func (p *TServerSocket) IsListening() bool {
return p.listener != nil
}
// Opens the listener socket, creating it if necessary.
func (p *TServerSocket) Open() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {
return NewTTransportException(ALREADY_OPEN, "Server socket already open")
}
if l, err := net.Listen(p.addr.Network(), p.addr.String()); err != nil {
return err
} else {
p.listener = l
}
return nil
}
func (p *TServerSocket) Addr() net.Addr {
if p.listener != nil {
return p.listener.Addr()
}
return p.addr
}
func (p *TServerSocket) Close() error {
defer func() {
p.listener = nil
}()
if p.IsListening() {
return p.listener.Close()
}
return nil
}
func (p *TServerSocket) Interrupt() error {
p.mu.Lock()
defer p.mu.Unlock()
p.interrupted = true
p.Close()
return nil
}


@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Server transport. Object which provides client transports.
type TServerTransport interface {
Listen() error
Accept() (TTransport, error)
Close() error
// Optional method implementation. This signals to the server transport
// that it should break out of any accept() or listen() that it is currently
// blocked on. This method, if implemented, MUST be thread safe, as it may
// be called from a different thread context than the other TServerTransport
// methods.
Interrupt() error
}

File diff suppressed because it is too large


@ -0,0 +1,227 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"log"
"runtime/debug"
"sync"
"sync/atomic"
)
/*
* This is not a typical TSimpleServer, as it does not block after accepting a socket.
* It is more like a TThreadedServer that handles different connections in different goroutines.
* This works as long as the Go client implements something like a connection pool on its side.
*/
type TSimpleServer struct {
closed int32
wg sync.WaitGroup
mu sync.Mutex
processorFactory TProcessorFactory
serverTransport TServerTransport
inputTransportFactory TTransportFactory
outputTransportFactory TTransportFactory
inputProtocolFactory TProtocolFactory
outputProtocolFactory TProtocolFactory
}
func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer {
return NewTSimpleServerFactory2(NewTProcessorFactory(processor), serverTransport)
}
func NewTSimpleServer4(processor TProcessor, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer {
return NewTSimpleServerFactory4(NewTProcessorFactory(processor),
serverTransport,
transportFactory,
protocolFactory,
)
}
func NewTSimpleServer6(processor TProcessor, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer {
return NewTSimpleServerFactory6(NewTProcessorFactory(processor),
serverTransport,
inputTransportFactory,
outputTransportFactory,
inputProtocolFactory,
outputProtocolFactory,
)
}
func NewTSimpleServerFactory2(processorFactory TProcessorFactory, serverTransport TServerTransport) *TSimpleServer {
return NewTSimpleServerFactory6(processorFactory,
serverTransport,
NewTTransportFactory(),
NewTTransportFactory(),
NewTBinaryProtocolFactoryDefault(),
NewTBinaryProtocolFactoryDefault(),
)
}
func NewTSimpleServerFactory4(processorFactory TProcessorFactory, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer {
return NewTSimpleServerFactory6(processorFactory,
serverTransport,
transportFactory,
transportFactory,
protocolFactory,
protocolFactory,
)
}
func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer {
return &TSimpleServer{
processorFactory: processorFactory,
serverTransport: serverTransport,
inputTransportFactory: inputTransportFactory,
outputTransportFactory: outputTransportFactory,
inputProtocolFactory: inputProtocolFactory,
outputProtocolFactory: outputProtocolFactory,
}
}
func (p *TSimpleServer) ProcessorFactory() TProcessorFactory {
return p.processorFactory
}
func (p *TSimpleServer) ServerTransport() TServerTransport {
return p.serverTransport
}
func (p *TSimpleServer) InputTransportFactory() TTransportFactory {
return p.inputTransportFactory
}
func (p *TSimpleServer) OutputTransportFactory() TTransportFactory {
return p.outputTransportFactory
}
func (p *TSimpleServer) InputProtocolFactory() TProtocolFactory {
return p.inputProtocolFactory
}
func (p *TSimpleServer) OutputProtocolFactory() TProtocolFactory {
return p.outputProtocolFactory
}
func (p *TSimpleServer) Listen() error {
return p.serverTransport.Listen()
}
func (p *TSimpleServer) innerAccept() (int32, error) {
client, err := p.serverTransport.Accept()
p.mu.Lock()
defer p.mu.Unlock()
closed := atomic.LoadInt32(&p.closed)
if closed != 0 {
return closed, nil
}
if err != nil {
return 0, err
}
if client != nil {
p.wg.Add(1)
go func() {
defer p.wg.Done()
if err := p.processRequests(client); err != nil {
log.Println("error processing request:", err)
}
}()
}
return 0, nil
}
func (p *TSimpleServer) AcceptLoop() error {
for {
closed, err := p.innerAccept()
if err != nil {
return err
}
if closed != 0 {
return nil
}
}
}
func (p *TSimpleServer) Serve() error {
err := p.Listen()
if err != nil {
return err
}
return p.AcceptLoop()
}
func (p *TSimpleServer) Stop() error {
p.mu.Lock()
defer p.mu.Unlock()
if atomic.LoadInt32(&p.closed) != 0 {
return nil
}
atomic.StoreInt32(&p.closed, 1)
p.serverTransport.Interrupt()
p.wg.Wait()
return nil
}
func (p *TSimpleServer) processRequests(client TTransport) error {
processor := p.processorFactory.GetProcessor(client)
inputTransport, err := p.inputTransportFactory.GetTransport(client)
if err != nil {
return err
}
outputTransport, err := p.outputTransportFactory.GetTransport(client)
if err != nil {
return err
}
inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport)
outputProtocol := p.outputProtocolFactory.GetProtocol(outputTransport)
defer func() {
if e := recover(); e != nil {
log.Printf("panic in processor: %s: %s", e, debug.Stack())
}
}()
if inputTransport != nil {
defer inputTransport.Close()
}
if outputTransport != nil {
defer outputTransport.Close()
}
for {
if atomic.LoadInt32(&p.closed) != 0 {
return nil
}
ok, err := processor.Process(defaultCtx, inputProtocol, outputProtocol)
// A clean END_OF_FILE means the client hung up; any other transport
// error is fatal for this connection.
if te, isTransport := err.(TTransportException); isTransport {
if te.TypeId() == END_OF_FILE {
return nil
}
return te
}
// An unknown method is recoverable: keep serving the connection.
if ae, isApp := err.(TApplicationException); isApp && ae.TypeId() == UNKNOWN_METHOD {
continue
}
if !ok {
break
}
}
return nil
}
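
A wiring sketch for this server, pairing it with the TServerSocket defined earlier. NewTFramedTransportFactory is assumed from the same library (its file is not shown in this excerpt), and the processor would come from thrift-generated code:

// serve blocks in AcceptLoop until Stop is called from another goroutine.
func serve(addr string, processor thrift.TProcessor) error {
    serverTransport, err := thrift.NewTServerSocket(addr)
    if err != nil {
        return err
    }
    server := thrift.NewTSimpleServer4(
        processor,
        serverTransport,
        thrift.NewTFramedTransportFactory(thrift.NewTTransportFactory()),
        thrift.NewTBinaryProtocolFactoryDefault(),
    )
    return server.Serve()
}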


@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"net"
"time"
)
type TSocket struct {
conn net.Conn
addr net.Addr
timeout time.Duration
}
// NewTSocket creates a net.Conn-backed TTransport, given a host and port
//
// Example:
// trans, err := thrift.NewTSocket("localhost:9090")
func NewTSocket(hostPort string) (*TSocket, error) {
return NewTSocketTimeout(hostPort, 0)
}
// NewTSocketTimeout creates a net.Conn-backed TTransport, given a host and port
// it also accepts a timeout as a time.Duration
func NewTSocketTimeout(hostPort string, timeout time.Duration) (*TSocket, error) {
addr, err := net.ResolveTCPAddr("tcp", hostPort)
if err != nil {
return nil, err
}
return NewTSocketFromAddrTimeout(addr, timeout), nil
}
// Creates a TSocket from a net.Addr
func NewTSocketFromAddrTimeout(addr net.Addr, timeout time.Duration) *TSocket {
return &TSocket{addr: addr, timeout: timeout}
}
// Creates a TSocket from an existing net.Conn
func NewTSocketFromConnTimeout(conn net.Conn, timeout time.Duration) *TSocket {
return &TSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout}
}
// Sets the socket timeout
func (p *TSocket) SetTimeout(timeout time.Duration) error {
p.timeout = timeout
return nil
}
func (p *TSocket) pushDeadline(read, write bool) {
var t time.Time
if p.timeout > 0 {
t = time.Now().Add(p.timeout)
}
if read && write {
p.conn.SetDeadline(t)
} else if read {
p.conn.SetReadDeadline(t)
} else if write {
p.conn.SetWriteDeadline(t)
}
}
// Connects the socket, creating a new socket object if necessary.
func (p *TSocket) Open() error {
if p.IsOpen() {
return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
}
if p.addr == nil {
return NewTTransportException(NOT_OPEN, "Cannot open nil address.")
}
if len(p.addr.Network()) == 0 {
return NewTTransportException(NOT_OPEN, "Cannot open bad network name.")
}
if len(p.addr.String()) == 0 {
return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
}
var err error
if p.conn, err = net.DialTimeout(p.addr.Network(), p.addr.String(), p.timeout); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
return nil
}
// Retrieve the underlying net.Conn
func (p *TSocket) Conn() net.Conn {
return p.conn
}
// Returns true if the connection is open
func (p *TSocket) IsOpen() bool {
return p.conn != nil
}
// Closes the socket.
func (p *TSocket) Close() error {
// Close the socket
if p.conn != nil {
err := p.conn.Close()
if err != nil {
return err
}
p.conn = nil
}
return nil
}
//Returns the remote address of the socket.
func (p *TSocket) Addr() net.Addr {
return p.addr
}
func (p *TSocket) Read(buf []byte) (int, error) {
if !p.IsOpen() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(true, false)
n, err := p.conn.Read(buf)
return n, NewTTransportExceptionFromError(err)
}
func (p *TSocket) Write(buf []byte) (int, error) {
if !p.IsOpen() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(false, true)
return p.conn.Write(buf)
}
func (p *TSocket) Flush(ctx context.Context) error {
return nil
}
func (p *TSocket) Interrupt() error {
if !p.IsOpen() {
return nil
}
return p.conn.Close()
}
func (p *TSocket) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the truth is, we just don't know unless a framed transport is used
}
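
The client-side counterpart, matching the framed/binary stack from the multiplexing example earlier in this diff (a sketch; error paths trimmed to the essentials):

// dial opens a TCP socket, frames it, and returns a ready protocol.
func dial(addr string) (thrift.TTransport, thrift.TProtocol, error) {
    sock, err := thrift.NewTSocket(addr)
    if err != nil {
        return nil, nil, err
    }
    trans := thrift.NewTFramedTransport(sock)
    if err := trans.Open(); err != nil { // opens the underlying socket too
        return nil, nil, err
    }
    return trans, thrift.NewTBinaryProtocolTransport(trans), nil
}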


@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"crypto/tls"
"net"
"time"
)
type TSSLServerSocket struct {
listener net.Listener
addr net.Addr
clientTimeout time.Duration
interrupted bool
cfg *tls.Config
}
func NewTSSLServerSocket(listenAddr string, cfg *tls.Config) (*TSSLServerSocket, error) {
return NewTSSLServerSocketTimeout(listenAddr, cfg, 0)
}
func NewTSSLServerSocketTimeout(listenAddr string, cfg *tls.Config, clientTimeout time.Duration) (*TSSLServerSocket, error) {
if cfg.MinVersion == 0 {
cfg.MinVersion = tls.VersionTLS10
}
addr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, err
}
return &TSSLServerSocket{addr: addr, clientTimeout: clientTimeout, cfg: cfg}, nil
}
func (p *TSSLServerSocket) Listen() error {
if p.IsListening() {
return nil
}
l, err := tls.Listen(p.addr.Network(), p.addr.String(), p.cfg)
if err != nil {
return err
}
p.listener = l
return nil
}
func (p *TSSLServerSocket) Accept() (TTransport, error) {
if p.interrupted {
return nil, errTransportInterrupted
}
if p.listener == nil {
return nil, NewTTransportException(NOT_OPEN, "No underlying server socket")
}
conn, err := p.listener.Accept()
if err != nil {
return nil, NewTTransportExceptionFromError(err)
}
return NewTSSLSocketFromConnTimeout(conn, p.cfg, p.clientTimeout), nil
}
// Checks whether the socket is listening.
func (p *TSSLServerSocket) IsListening() bool {
return p.listener != nil
}
// Opens the listener socket, creating it if necessary.
func (p *TSSLServerSocket) Open() error {
if p.IsListening() {
return NewTTransportException(ALREADY_OPEN, "Server socket already open")
}
if l, err := tls.Listen(p.addr.Network(), p.addr.String(), p.cfg); err != nil {
return err
} else {
p.listener = l
}
return nil
}
func (p *TSSLServerSocket) Addr() net.Addr {
return p.addr
}
func (p *TSSLServerSocket) Close() error {
defer func() {
p.listener = nil
}()
if p.IsListening() {
return p.listener.Close()
}
return nil
}
func (p *TSSLServerSocket) Interrupt() error {
p.interrupted = true
return nil
}


@ -0,0 +1,176 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"crypto/tls"
"net"
"time"
)
type TSSLSocket struct {
conn net.Conn
// hostPort contains host:port (e.g. "asdf.com:12345"). The field is
// only valid if addr is nil.
hostPort string
// addr is nil when hostPort is not "", and is only used when the
// TSSLSocket is constructed from a net.Addr.
addr net.Addr
timeout time.Duration
cfg *tls.Config
}
// NewTSSLSocket creates a net.Conn-backed TTransport, given a host and port and tls Configuration
//
// Example:
// trans, err := thrift.NewTSSLSocket("localhost:9090", nil)
func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) {
return NewTSSLSocketTimeout(hostPort, cfg, 0)
}
// NewTSSLSocketTimeout creates a net.Conn-backed TTransport, given a host and port
// it also accepts a tls Configuration and a timeout as a time.Duration
func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, timeout time.Duration) (*TSSLSocket, error) {
if cfg == nil {
cfg = &tls.Config{} // the doc example above passes nil; avoid a panic
}
if cfg.MinVersion == 0 {
cfg.MinVersion = tls.VersionTLS10
}
return &TSSLSocket{hostPort: hostPort, timeout: timeout, cfg: cfg}, nil
}
// Creates a TSSLSocket from a net.Addr
func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
return &TSSLSocket{addr: addr, timeout: timeout, cfg: cfg}
}
// Creates a TSSLSocket from an existing net.Conn
func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg}
}
// Sets the socket timeout
func (p *TSSLSocket) SetTimeout(timeout time.Duration) error {
p.timeout = timeout
return nil
}
func (p *TSSLSocket) pushDeadline(read, write bool) {
var t time.Time
if p.timeout > 0 {
t = time.Now().Add(p.timeout)
}
if read && write {
p.conn.SetDeadline(t)
} else if read {
p.conn.SetReadDeadline(t)
} else if write {
p.conn.SetWriteDeadline(t)
}
}
// Connects the socket, creating a new socket object if necessary.
func (p *TSSLSocket) Open() error {
var err error
// If we have a hostname, we need to pass the hostname to tls.Dial for
// certificate hostname checks.
if p.hostPort != "" {
if p.conn, err = tls.DialWithDialer(&net.Dialer{
Timeout: p.timeout}, "tcp", p.hostPort, p.cfg); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
} else {
if p.IsOpen() {
return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
}
if p.addr == nil {
return NewTTransportException(NOT_OPEN, "Cannot open nil address.")
}
if len(p.addr.Network()) == 0 {
return NewTTransportException(NOT_OPEN, "Cannot open bad network name.")
}
if len(p.addr.String()) == 0 {
return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
}
if p.conn, err = tls.DialWithDialer(&net.Dialer{
Timeout: p.timeout}, p.addr.Network(), p.addr.String(), p.cfg); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
}
return nil
}
// Retrieve the underlying net.Conn
func (p *TSSLSocket) Conn() net.Conn {
return p.conn
}
// Returns true if the connection is open
func (p *TSSLSocket) IsOpen() bool {
return p.conn != nil
}
// Closes the socket.
func (p *TSSLSocket) Close() error {
// Close the socket
if p.conn != nil {
err := p.conn.Close()
if err != nil {
return err
}
p.conn = nil
}
return nil
}
func (p *TSSLSocket) Read(buf []byte) (int, error) {
if !p.IsOpen() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(true, false)
n, err := p.conn.Read(buf)
return n, NewTTransportExceptionFromError(err)
}
func (p *TSSLSocket) Write(buf []byte) (int, error) {
if !p.IsOpen() {
return 0, NewTTransportException(NOT_OPEN, "Connection not open")
}
p.pushDeadline(false, true)
return p.conn.Write(buf)
}
func (p *TSSLSocket) Flush(ctx context.Context) error {
return nil
}
func (p *TSSLSocket) Interrupt() error {
if !p.IsOpen() {
return nil
}
return p.conn.Close()
}
func (p *TSSLSocket) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the truth is, we just don't know unless a framed transport is used
}
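
The TLS variant has the same shape; a sketch (the address is a placeholder, and passing a non-nil *tls.Config keeps the constructor's MinVersion probe safe):

// dialTLS dials and completes the TLS handshake on Open.
func dialTLS(addr string) (*thrift.TSSLSocket, error) {
    cfg := &tls.Config{MinVersion: tls.VersionTLS12}
    sock, err := thrift.NewTSSLSocket(addr, cfg)
    if err != nil {
        return nil, err
    }
    if err := sock.Open(); err != nil {
        return nil, err
    }
    return sock, nil
}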


@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"errors"
"io"
)
var errTransportInterrupted = errors.New("Transport Interrupted")
type Flusher interface {
Flush() (err error)
}
type ContextFlusher interface {
Flush(ctx context.Context) (err error)
}
type ReadSizeProvider interface {
RemainingBytes() (num_bytes uint64)
}
// Encapsulates the I/O layer
type TTransport interface {
io.ReadWriteCloser
ContextFlusher
ReadSizeProvider
// Opens the transport for communication
Open() error
// Returns true if the transport is open
IsOpen() bool
}
type stringWriter interface {
WriteString(s string) (n int, err error)
}
// This is "enchanced" transport with extra capabilities. You need to use one of these
// to construct protocol.
// Notably, TSocket does not implement this interface, and it is always a mistake to use
// TSocket directly in protocol.
type TRichTransport interface {
io.ReadWriter
io.ByteReader
io.ByteWriter
stringWriter
ContextFlusher
ReadSizeProvider
}


@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"errors"
"io"
)
type timeoutable interface {
Timeout() bool
}
// Thrift Transport exception
type TTransportException interface {
TException
TypeId() int
Err() error
}
const (
UNKNOWN_TRANSPORT_EXCEPTION = 0
NOT_OPEN = 1
ALREADY_OPEN = 2
TIMED_OUT = 3
END_OF_FILE = 4
)
type tTransportException struct {
typeId int
err error
}
func (p *tTransportException) TypeId() int {
return p.typeId
}
func (p *tTransportException) Error() string {
return p.err.Error()
}
func (p *tTransportException) Err() error {
return p.err
}
func NewTTransportException(t int, e string) TTransportException {
return &tTransportException{typeId: t, err: errors.New(e)}
}
func NewTTransportExceptionFromError(e error) TTransportException {
if e == nil {
return nil
}
switch v := e.(type) {
case TTransportException:
return v
case timeoutable:
if v.Timeout() {
return &tTransportException{typeId: TIMED_OUT, err: e}
}
}
if e == io.EOF {
return &tTransportException{typeId: END_OF_FILE, err: e}
}
return &tTransportException{typeId: UNKNOWN_TRANSPORT_EXCEPTION, err: e}
}
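
This classification is what lets TSimpleServer above treat a client disconnect as a clean shutdown; a minimal check:

package main

import (
    "fmt"
    "io"

    "git.apache.org/thrift.git/lib/go/thrift"
)

func main() {
    te := thrift.NewTTransportExceptionFromError(io.EOF)
    fmt.Println(te.TypeId() == thrift.END_OF_FILE) // true: io.EOF maps to END_OF_FILE
}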


@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Factory class used to create wrapped instance of Transports.
// This is used primarily in servers, which get Transports from
// a ServerTransport and then may want to mutate them (e.g. create
// a BufferedTransport from the underlying base transport)
type TTransportFactory interface {
GetTransport(trans TTransport) (TTransport, error)
}
type tTransportFactory struct{}
// Return a wrapped instance of the base Transport.
func (p *tTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
return trans, nil
}
func NewTTransportFactory() TTransportFactory {
return &tTransportFactory{}
}

vendor/git.apache.org/thrift.git/lib/go/thrift/type.go generated vendored Normal file (69 lines)

@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Type constants in the Thrift protocol
type TType byte
const (
STOP = 0
VOID = 1
BOOL = 2
BYTE = 3
I08 = 3
DOUBLE = 4
I16 = 6
I32 = 8
I64 = 10
STRING = 11
UTF7 = 11
STRUCT = 12
MAP = 13
SET = 14
LIST = 15
UTF8 = 16
UTF16 = 17
// BINARY = 18 is wrong and unused
)
var typeNames = map[int]string{
STOP: "STOP",
VOID: "VOID",
BOOL: "BOOL",
BYTE: "BYTE",
DOUBLE: "DOUBLE",
I16: "I16",
I32: "I32",
I64: "I64",
STRING: "STRING",
STRUCT: "STRUCT",
MAP: "MAP",
SET: "SET",
LIST: "LIST",
UTF8: "UTF8",
UTF16: "UTF16",
}
func (p TType) String() string {
if s, ok := typeNames[int(p)]; ok {
return s
}
return "Unknown"
}


@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"compress/zlib"
"context"
"io"
"log"
)
// TZlibTransportFactory is a factory for TZlibTransport instances
type TZlibTransportFactory struct {
level int
factory TTransportFactory
}
// TZlibTransport is a TTransport implementation that makes use of zlib compression.
type TZlibTransport struct {
reader io.ReadCloser
transport TTransport
writer *zlib.Writer
}
// GetTransport constructs a new instance of NewTZlibTransport
func (p *TZlibTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if p.factory != nil {
// wrap other factory
var err error
trans, err = p.factory.GetTransport(trans)
if err != nil {
return nil, err
}
}
return NewTZlibTransport(trans, p.level)
}
// NewTZlibTransportFactory constructs a new instance of NewTZlibTransportFactory
func NewTZlibTransportFactory(level int) *TZlibTransportFactory {
return &TZlibTransportFactory{level: level, factory: nil}
}
// NewTZlibTransportFactory constructs a new instance of TZlibTransportFactory
// as a wrapper over existing transport factory
func NewTZlibTransportFactoryWithFactory(level int, factory TTransportFactory) *TZlibTransportFactory {
return &TZlibTransportFactory{level: level, factory: factory}
}
// NewTZlibTransport constructs a new instance of TZlibTransport
func NewTZlibTransport(trans TTransport, level int) (*TZlibTransport, error) {
w, err := zlib.NewWriterLevel(trans, level)
if err != nil {
log.Println(err)
return nil, err
}
return &TZlibTransport{
writer: w,
transport: trans,
}, nil
}
// Close closes the reader and writer (flushing any unwritten data) and closes
// the underlying transport.
func (z *TZlibTransport) Close() error {
if z.reader != nil {
if err := z.reader.Close(); err != nil {
return err
}
}
if err := z.writer.Close(); err != nil {
return err
}
return z.transport.Close()
}
// Flush flushes the writer and its underlying transport.
func (z *TZlibTransport) Flush(ctx context.Context) error {
if err := z.writer.Flush(); err != nil {
return err
}
return z.transport.Flush(ctx)
}
// IsOpen returns true if the transport is open
func (z *TZlibTransport) IsOpen() bool {
return z.transport.IsOpen()
}
// Open opens the transport for communication
func (z *TZlibTransport) Open() error {
return z.transport.Open()
}
func (z *TZlibTransport) Read(p []byte) (int, error) {
if z.reader == nil {
r, err := zlib.NewReader(z.transport)
if err != nil {
return 0, NewTTransportExceptionFromError(err)
}
z.reader = r
}
return z.reader.Read(p)
}
// RemainingBytes returns the size in bytes of the data that is still to be
// read.
func (z *TZlibTransport) RemainingBytes() uint64 {
return z.transport.RemainingBytes()
}
func (z *TZlibTransport) Write(p []byte) (int, error) {
return z.writer.Write(p)
}
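A minimal usage sketch, not taken from the repository: wrapping a socket transport so writes are transparently deflated and reads inflated; the endpoint is a placeholder.
package main

import (
	"compress/zlib"
	"log"

	"git.apache.org/thrift.git/lib/go/thrift"
)

func main() {
	sock, err := thrift.NewTSocket("localhost:9090") // placeholder endpoint
	if err != nil {
		log.Fatal(err)
	}
	factory := thrift.NewTZlibTransportFactory(zlib.BestSpeed)
	trans, err := factory.GetTransport(sock)
	if err != nil {
		log.Fatal(err)
	}
	if err := trans.Open(); err != nil {
		log.Fatal(err)
	}
	defer trans.Close() // flushes pending zlib output, then closes the socket
}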

202
vendor/github.com/minio/parquet-go/LICENSE generated vendored Normal file
View File

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

36
vendor/github.com/minio/parquet-go/Makefile generated vendored Normal file
View File

@ -0,0 +1,36 @@
GOPATH := $(shell go env GOPATH)
all: check
getdeps:
@if [ ! -f ${GOPATH}/bin/golint ]; then echo "Installing golint" && go get -u golang.org/x/lint/golint; fi
@if [ ! -f ${GOPATH}/bin/gocyclo ]; then echo "Installing gocyclo" && go get -u github.com/fzipp/gocyclo; fi
@if [ ! -f ${GOPATH}/bin/misspell ]; then echo "Installing misspell" && go get -u github.com/client9/misspell/cmd/misspell; fi
@if [ ! -f ${GOPATH}/bin/ineffassign ]; then echo "Installing ineffassign" && go get -u github.com/gordonklaus/ineffassign; fi
vet:
@echo "Running $@"
@go tool vet -atomic -bool -copylocks -nilfunc -printf -shadow -rangeloops -unreachable -unsafeptr -unusedresult *.go
fmt:
@echo "Running $@"
@gofmt -d *.go
lint:
@echo "Running $@"
@${GOPATH}/bin/golint -set_exit_status
cyclo:
@echo "Running $@"
@${GOPATH}/bin/gocyclo -over 200 .
spelling:
@${GOPATH}/bin/misspell -locale US -error *.go README.md
ineffassign:
@echo "Running $@"
@${GOPATH}/bin/ineffassign .
check: getdeps vet fmt lint cyclo spelling ineffassign
@echo "Running unit tests"
@go test -tags kqueue .

1
vendor/github.com/minio/parquet-go/README.md generated vendored Normal file
View File

@ -0,0 +1 @@
# parquet-go

152
vendor/github.com/minio/parquet-go/column.go generated vendored Normal file
View File

@ -0,0 +1,152 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parquet
import (
"io"
"strings"
"git.apache.org/thrift.git/lib/go/thrift"
"github.com/minio/minio-go/pkg/set"
"github.com/minio/parquet-go/gen-go/parquet"
)
func getColumns(
rowGroup *parquet.RowGroup,
columnNames set.StringSet,
schemaElements []*parquet.SchemaElement,
getReaderFunc GetReaderFunc,
) (nameColumnMap map[string]*column, err error) {
nameIndexMap := make(map[string]int)
for colIndex, columnChunk := range rowGroup.GetColumns() {
meta := columnChunk.GetMetaData()
columnName := strings.Join(meta.GetPathInSchema(), ".")
if columnNames != nil && !columnNames.Contains(columnName) {
continue
}
// Ignore column spanning into another file.
if columnChunk.GetFilePath() != "" {
continue
}
offset := meta.GetDataPageOffset()
if meta.DictionaryPageOffset != nil {
offset = meta.GetDictionaryPageOffset()
}
size := meta.GetTotalCompressedSize()
rc, err := getReaderFunc(offset, size)
if err != nil {
return nil, err
}
thriftReader := thrift.NewTBufferedTransport(thrift.NewStreamTransportR(rc), int(size))
if nameColumnMap == nil {
nameColumnMap = make(map[string]*column)
}
nameColumnMap[columnName] = &column{
name: columnName,
metadata: meta,
schemaElements: schemaElements,
rc: rc,
thriftReader: thriftReader,
valueType: meta.GetType(),
}
nameIndexMap[columnName] = colIndex
}
for name := range nameColumnMap {
nameColumnMap[name].nameIndexMap = nameIndexMap
}
return nameColumnMap, nil
}
type column struct {
name string
endOfValues bool
valueIndex int
valueType parquet.Type
metadata *parquet.ColumnMetaData
schemaElements []*parquet.SchemaElement
nameIndexMap map[string]int
dictPage *page
dataTable *table
rc io.ReadCloser
thriftReader *thrift.TBufferedTransport
}
func (column *column) close() (err error) {
if column.rc != nil {
err = column.rc.Close()
column.rc = nil
}
return err
}
func (column *column) readPage() {
page, _, _, err := readPage(
column.thriftReader,
column.metadata,
column.nameIndexMap,
column.schemaElements,
)
if err != nil {
column.endOfValues = true
return
}
if page.Header.GetType() == parquet.PageType_DICTIONARY_PAGE {
column.dictPage = page
column.readPage()
return
}
page.decode(column.dictPage)
if column.dataTable == nil {
column.dataTable = newTableFromTable(page.DataTable)
}
column.dataTable.Merge(page.DataTable)
}
func (column *column) read() (value interface{}, valueType parquet.Type) {
if column.dataTable == nil {
column.readPage()
column.valueIndex = 0
}
if column.endOfValues {
return nil, column.metadata.GetType()
}
value = column.dataTable.Values[column.valueIndex]
column.valueIndex++
if len(column.dataTable.Values) == column.valueIndex {
column.dataTable = nil
}
return value, column.metadata.GetType()
}
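A hypothetical in-package sketch of how a column is consumed: read() lazily pulls and merges pages until readPage flips endOfValues, so a caller can simply loop.
// Hypothetical helper, for illustration only.
func drainColumn(col *column) (values []interface{}) {
	for {
		v, _ := col.read()
		if col.endOfValues {
			return values
		}
		values = append(values, v)
	}
}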

57
vendor/github.com/minio/parquet-go/compression.go generated vendored Normal file
View File

@ -0,0 +1,57 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parquet
import (
"bytes"
"compress/gzip"
"fmt"
"io/ioutil"
"github.com/golang/snappy"
"github.com/minio/parquet-go/gen-go/parquet"
"github.com/pierrec/lz4"
lzo "github.com/rasky/go-lzo"
)
type compressionCodec parquet.CompressionCodec
func (c compressionCodec) uncompress(buf []byte) ([]byte, error) {
switch parquet.CompressionCodec(c) {
case parquet.CompressionCodec_UNCOMPRESSED:
return buf, nil
case parquet.CompressionCodec_SNAPPY:
return snappy.Decode(nil, buf)
case parquet.CompressionCodec_GZIP:
reader, err := gzip.NewReader(bytes.NewReader(buf))
if err != nil {
return nil, err
}
defer reader.Close()
return ioutil.ReadAll(reader)
case parquet.CompressionCodec_LZO:
return lzo.Decompress1X(bytes.NewReader(buf), len(buf), 0)
case parquet.CompressionCodec_LZ4:
return ioutil.ReadAll(lz4.NewReader(bytes.NewReader(buf)))
}
return nil, fmt.Errorf("invalid compression codec %v", c)
}
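A short in-package sketch (hypothetical helper, not in the file): round-tripping a buffer through the SNAPPY branch of uncompress, using the standard snappy Encode API.
// Hypothetical test helper for package parquet.
func snappyRoundTrip(src []byte) ([]byte, error) {
	compressed := snappy.Encode(nil, src)
	codec := compressionCodec(parquet.CompressionCodec_SNAPPY)
	return codec.uncompress(compressed)
}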

506
vendor/github.com/minio/parquet-go/decode.go generated vendored Normal file
View File

@ -0,0 +1,506 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parquet
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"github.com/minio/parquet-go/gen-go/parquet"
)
func uint32ToBytes(v uint32) []byte {
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, v)
return buf
}
func uint64ToBytes(v uint64) []byte {
buf := make([]byte, 8)
binary.LittleEndian.PutUint64(buf, v)
return buf
}
func bytesToUint32(buf []byte) uint32 {
return binary.LittleEndian.Uint32(buf)
}
func bytesToUint64(buf []byte) uint64 {
return binary.LittleEndian.Uint64(buf)
}
func i64sToi32s(i64s []int64) (i32s []int32) {
i32s = make([]int32, len(i64s))
for i := range i64s {
i32s[i] = int32(i64s[i])
}
return i32s
}
func readBitPacked(reader *bytes.Reader, header, bitWidth uint64) (result []int64, err error) {
count := header * 8
if count == 0 {
return result, nil
}
if bitWidth == 0 {
return make([]int64, count), nil
}
data := make([]byte, header*bitWidth)
if _, err = reader.Read(data); err != nil {
return nil, err
}
var val, used, left, b uint64
valNeedBits := bitWidth
i := -1
for {
if left <= 0 {
i++
if i >= len(data) {
break
}
b = uint64(data[i])
left = 8
used = 0
}
if left >= valNeedBits {
val |= ((b >> used) & ((1 << valNeedBits) - 1)) << (bitWidth - valNeedBits)
result = append(result, int64(val))
val = 0
left -= valNeedBits
used += valNeedBits
valNeedBits = bitWidth
} else {
val |= (b >> used) << (bitWidth - valNeedBits)
valNeedBits -= left
left = 0
}
}
return result, nil
}
func readBools(reader *bytes.Reader, count uint64) (result []bool, err error) {
i64s, err := readBitPacked(reader, count, 1)
if err != nil {
return nil, err
}
var i uint64
for i = 0; i < count; i++ {
result = append(result, i64s[i] > 0)
}
return result, nil
}
func readInt32s(reader *bytes.Reader, count uint64) (result []int32, err error) {
buf := make([]byte, 4)
var i uint64
for i = 0; i < count; i++ {
if _, err = reader.Read(buf); err != nil {
return nil, err
}
result = append(result, int32(bytesToUint32(buf)))
}
return result, nil
}
func readInt64s(reader *bytes.Reader, count uint64) (result []int64, err error) {
buf := make([]byte, 8)
var i uint64
for i = 0; i < count; i++ {
if _, err = reader.Read(buf); err != nil {
return nil, err
}
result = append(result, int64(bytesToUint64(buf)))
}
return result, nil
}
func readInt96s(reader *bytes.Reader, count uint64) (result [][]byte, err error) {
var i uint64
for i = 0; i < count; i++ {
buf := make([]byte, 12)
if _, err = reader.Read(buf); err != nil {
return nil, err
}
result = append(result, buf)
}
return result, nil
}
func readFloats(reader *bytes.Reader, count uint64) (result []float32, err error) {
buf := make([]byte, 4)
var i uint64
for i = 0; i < count; i++ {
if _, err = reader.Read(buf); err != nil {
return nil, err
}
result = append(result, math.Float32frombits(bytesToUint32(buf)))
}
return result, nil
}
func readDoubles(reader *bytes.Reader, count uint64) (result []float64, err error) {
buf := make([]byte, 8)
var i uint64
for i = 0; i < count; i++ {
if _, err = reader.Read(buf); err != nil {
return nil, err
}
result = append(result, math.Float64frombits(bytesToUint64(buf)))
}
return result, nil
}
func readByteArrays(reader *bytes.Reader, count uint64) (result [][]byte, err error) {
buf := make([]byte, 4)
var i uint64
for i = 0; i < count; i++ {
if _, err = reader.Read(buf); err != nil {
return nil, err
}
data := make([]byte, bytesToUint32(buf))
if _, err = reader.Read(data); err != nil {
return nil, err
}
result = append(result, data)
}
return result, nil
}
func readFixedLenByteArrays(reader *bytes.Reader, count, length uint64) (result [][]byte, err error) {
var i uint64
for i = 0; i < count; i++ {
data := make([]byte, length)
if _, err = reader.Read(data); err != nil {
return nil, err
}
result = append(result, data)
}
return result, nil
}
func readValues(reader *bytes.Reader, dataType parquet.Type, count, length uint64) (interface{}, error) {
switch dataType {
case parquet.Type_BOOLEAN:
return readBools(reader, count)
case parquet.Type_INT32:
return readInt32s(reader, count)
case parquet.Type_INT64:
return readInt64s(reader, count)
case parquet.Type_INT96:
return readInt96s(reader, count)
case parquet.Type_FLOAT:
return readFloats(reader, count)
case parquet.Type_DOUBLE:
return readDoubles(reader, count)
case parquet.Type_BYTE_ARRAY:
return readByteArrays(reader, count)
case parquet.Type_FIXED_LEN_BYTE_ARRAY:
return readFixedLenByteArrays(reader, count, length)
}
return nil, fmt.Errorf("unknown parquet type %v", dataType)
}
func readUnsignedVarInt(reader *bytes.Reader) (v uint64, err error) {
var b byte
var shift uint64
for {
if b, err = reader.ReadByte(); err != nil {
return 0, err
}
if v |= ((uint64(b) & 0x7F) << shift); b&0x80 == 0 {
break
}
shift += 7
}
return v, nil
}
func readRLE(reader *bytes.Reader, header, bitWidth uint64) (result []int64, err error) {
width := (bitWidth + 7) / 8
data := make([]byte, width)
if width > 0 {
if _, err = reader.Read(data); err != nil {
return nil, err
}
}
if width < 4 {
data = append(data, make([]byte, 4-width)...)
}
val := int64(bytesToUint32(data))
count := header >> 1
result = make([]int64, count)
for i := range result {
result[i] = val
}
return result, nil
}
func readRLEBitPackedHybrid(reader *bytes.Reader, length, bitWidth uint64) (result []int64, err error) {
if length <= 0 {
var i32s []int32
i32s, err = readInt32s(reader, 1)
if err != nil {
return nil, err
}
length = uint64(i32s[0])
}
buf := make([]byte, length)
if _, err = reader.Read(buf); err != nil {
return nil, err
}
reader = bytes.NewReader(buf)
for reader.Len() > 0 {
header, err := readUnsignedVarInt(reader)
if err != nil {
return nil, err
}
var i64s []int64
if header&1 == 0 {
i64s, err = readRLE(reader, header, bitWidth)
} else {
i64s, err = readBitPacked(reader, header>>1, bitWidth)
}
if err != nil {
return nil, err
}
result = append(result, i64s...)
}
return result, nil
}
func readDeltaBinaryPackedInt(reader *bytes.Reader) (result []int64, err error) {
blockSize, err := readUnsignedVarInt(reader)
if err != nil {
return nil, err
}
numMiniblocksInBlock, err := readUnsignedVarInt(reader)
if err != nil {
return nil, err
}
numValues, err := readUnsignedVarInt(reader)
if err != nil {
return nil, err
}
firstValueZigZag, err := readUnsignedVarInt(reader)
if err != nil {
return nil, err
}
v := int64(firstValueZigZag>>1) ^ (-int64(firstValueZigZag & 1))
result = append(result, v)
numValuesInMiniBlock := blockSize / numMiniblocksInBlock
bitWidths := make([]uint64, numMiniblocksInBlock)
for uint64(len(result)) < numValues {
minDeltaZigZag, err := readUnsignedVarInt(reader)
if err != nil {
return nil, err
}
for i := 0; uint64(i) < numMiniblocksInBlock; i++ {
b, err := reader.ReadByte()
if err != nil {
return nil, err
}
bitWidths[i] = uint64(b)
}
minDelta := int64(minDeltaZigZag>>1) ^ (-int64(minDeltaZigZag & 1))
for i := 0; uint64(i) < numMiniblocksInBlock; i++ {
i64s, err := readBitPacked(reader, numValuesInMiniBlock/8, bitWidths[i])
if err != nil {
return nil, err
}
for j := range i64s {
v += i64s[j] + minDelta
result = append(result, v)
}
}
}
return result[:numValues], nil
}
func readDeltaLengthByteArrays(reader *bytes.Reader) (result [][]byte, err error) {
i64s, err := readDeltaBinaryPackedInt(reader)
if err != nil {
return nil, err
}
for i := 0; i < len(i64s); i++ {
arrays, err := readFixedLenByteArrays(reader, 1, uint64(i64s[i]))
if err != nil {
return nil, err
}
result = append(result, arrays[0])
}
return result, nil
}
func readDeltaByteArrays(reader *bytes.Reader) (result [][]byte, err error) {
i64s, err := readDeltaBinaryPackedInt(reader)
if err != nil {
return nil, err
}
suffixes, err := readDeltaLengthByteArrays(reader)
if err != nil {
return nil, err
}
result = append(result, suffixes[0])
for i := 1; i < len(i64s); i++ {
prefixLength := i64s[i]
val := append([]byte{}, result[i-1][:prefixLength]...)
val = append(val, suffixes[i]...)
result = append(result, val)
}
return result, nil
}
func readDataPageValues(
bytesReader *bytes.Reader,
encoding parquet.Encoding,
dataType parquet.Type,
convertedType parquet.ConvertedType,
count, bitWidth uint64,
) (result interface{}, resultDataType parquet.Type, err error) {
switch encoding {
case parquet.Encoding_PLAIN:
result, err = readValues(bytesReader, dataType, count, bitWidth)
return result, dataType, err
case parquet.Encoding_PLAIN_DICTIONARY:
b, err := bytesReader.ReadByte()
if err != nil {
return nil, -1, err
}
i64s, err := readRLEBitPackedHybrid(bytesReader, uint64(bytesReader.Len()), uint64(b))
if err != nil {
return nil, -1, err
}
return i64s[:count], parquet.Type_INT64, nil
case parquet.Encoding_RLE:
i64s, err := readRLEBitPackedHybrid(bytesReader, 0, bitWidth)
if err != nil {
return nil, -1, err
}
i64s = i64s[:count]
if dataType == parquet.Type_INT32 {
return i64sToi32s(i64s), parquet.Type_INT32, nil
}
return i64s, parquet.Type_INT64, nil
case parquet.Encoding_BIT_PACKED:
return nil, -1, fmt.Errorf("deprecated parquet encoding %v", parquet.Encoding_BIT_PACKED)
case parquet.Encoding_DELTA_BINARY_PACKED:
i64s, err := readDeltaBinaryPackedInt(bytesReader)
if err != nil {
return nil, -1, err
}
i64s = i64s[:count]
if dataType == parquet.Type_INT32 {
return i64sToi32s(i64s), parquet.Type_INT32, nil
}
return i64s, parquet.Type_INT64, nil
case parquet.Encoding_DELTA_LENGTH_BYTE_ARRAY:
byteSlices, err := readDeltaLengthByteArrays(bytesReader)
if err != nil {
return nil, -1, err
}
return byteSlices[:count], parquet.Type_FIXED_LEN_BYTE_ARRAY, nil
case parquet.Encoding_DELTA_BYTE_ARRAY:
byteSlices, err := readDeltaByteArrays(bytesReader)
if err != nil {
return nil, -1, err
}
return byteSlices[:count], parquet.Type_FIXED_LEN_BYTE_ARRAY, nil
}
return nil, -1, fmt.Errorf("unsupported parquet encoding %v", encoding)
}
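The delta readers above lean on zigzag decoding: an unsigned varint u maps to the signed value int64(u>>1) ^ -int64(u&1), so 0, 1, 2, 3, 4 decode to 0, -1, 1, -2, 2. A standalone restatement of that step (hypothetical helper mirroring the inline expressions):
// Hypothetical helper equivalent to the zigzag expressions in
// readDeltaBinaryPackedInt.
func zigzagDecode(u uint64) int64 {
	return int64(u>>1) ^ -int64(u&1)
}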

BIN
vendor/github.com/minio/parquet-go/example.parquet generated vendored Normal file

Binary file not shown.

View File

@ -0,0 +1,7 @@
// Autogenerated by Thrift Compiler (0.10.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
package parquet
var GoUnusedProtection__ int;

View File

@ -0,0 +1,20 @@
// Autogenerated by Thrift Compiler (0.10.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
package parquet
import (
"bytes"
"fmt"
"git.apache.org/thrift.git/lib/go/thrift"
)
// (needed to ensure safety because of naive import list construction.)
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = bytes.Equal
func init() {
}

File diff suppressed because it is too large Load Diff

22
vendor/github.com/minio/parquet-go/gen-parquet-format-pkg.sh generated vendored Executable file
View File

@ -0,0 +1,22 @@
#!/bin/bash
#
# Minio Cloud Storage, (C) 2018 Minio, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
set -e
rm -f parquet.thrift
wget -q https://github.com/apache/parquet-format/raw/df6132b94f273521a418a74442085fdd5a0aa009/src/main/thrift/parquet.thrift
thrift --gen go parquet.thrift

531
vendor/github.com/minio/parquet-go/page.go generated vendored Normal file
View File

@ -0,0 +1,531 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parquet
import (
"bytes"
"fmt"
"strings"
"git.apache.org/thrift.git/lib/go/thrift"
"github.com/minio/parquet-go/gen-go/parquet"
)
// getBitWidth - returns the number of bits required to represent num, e.g.
//
// num | width
// -----|-------
// 0 | 0
// 1 | 1
// 2 | 2
// 3 | 2
// 4 | 3
// 5 | 3
// ... | ...
// ... | ...
//
func getBitWidth(num uint64) (width uint64) {
for ; num != 0; num >>= 1 {
width++
}
return width
}
// getMaxDefLevel - get maximum definition level.
func getMaxDefLevel(nameIndexMap map[string]int, schemaElements []*parquet.SchemaElement, path []string) (v int) {
for i := 1; i <= len(path); i++ {
name := strings.Join(path[:i], ".")
if index, ok := nameIndexMap[name]; ok {
if schemaElements[index].GetRepetitionType() != parquet.FieldRepetitionType_REQUIRED {
v++
}
}
}
return v
}
// getMaxRepLevel - get maximum repetition level.
func getMaxRepLevel(nameIndexMap map[string]int, schemaElements []*parquet.SchemaElement, path []string) (v int) {
for i := 1; i <= len(path); i++ {
name := strings.Join(path[:i], ".")
if index, ok := nameIndexMap[name]; ok {
if schemaElements[index].GetRepetitionType() == parquet.FieldRepetitionType_REPEATED {
v++
}
}
}
return v
}
func readPageHeader(reader *thrift.TBufferedTransport) (*parquet.PageHeader, error) {
pageHeader := parquet.NewPageHeader()
if err := pageHeader.Read(thrift.NewTCompactProtocol(reader)); err != nil {
return nil, err
}
return pageHeader, nil
}
func readPageRawData(thriftReader *thrift.TBufferedTransport, metadata *parquet.ColumnMetaData) (page *page, err error) {
pageHeader, err := readPageHeader(thriftReader)
if err != nil {
return nil, err
}
switch pageType := pageHeader.GetType(); pageType {
case parquet.PageType_DICTIONARY_PAGE:
page = newDictPage()
case parquet.PageType_DATA_PAGE, parquet.PageType_DATA_PAGE_V2:
page = newDataPage()
default:
return nil, fmt.Errorf("unsupported page type %v", pageType)
}
compressedPageSize := pageHeader.GetCompressedPageSize()
buf := make([]byte, compressedPageSize)
if _, err := thriftReader.Read(buf); err != nil {
return nil, err
}
page.Header = pageHeader
page.CompressType = metadata.GetCodec()
page.RawData = buf
page.Path = append([]string{}, metadata.GetPathInSchema()...)
page.DataType = metadata.GetType()
return page, nil
}
func readPage(
thriftReader *thrift.TBufferedTransport,
metadata *parquet.ColumnMetaData,
columnNameIndexMap map[string]int,
schemaElements []*parquet.SchemaElement,
) (page *page, definitionLevels, numRows int64, err error) {
pageHeader, err := readPageHeader(thriftReader)
if err != nil {
return nil, 0, 0, err
}
read := func() (data []byte, err error) {
var repLevelsLen, defLevelsLen int32
var repLevelsBuf, defLevelsBuf []byte
if pageHeader.GetType() == parquet.PageType_DATA_PAGE_V2 {
repLevelsLen = pageHeader.DataPageHeaderV2.GetRepetitionLevelsByteLength()
repLevelsBuf = make([]byte, repLevelsLen)
if _, err = thriftReader.Read(repLevelsBuf); err != nil {
return nil, err
}
defLevelsLen = pageHeader.DataPageHeaderV2.GetDefinitionLevelsByteLength()
defLevelsBuf = make([]byte, defLevelsLen)
if _, err = thriftReader.Read(defLevelsBuf); err != nil {
return nil, err
}
}
dataBuf := make([]byte, pageHeader.GetCompressedPageSize()-repLevelsLen-defLevelsLen)
if _, err = thriftReader.Read(dataBuf); err != nil {
return nil, err
}
if dataBuf, err = compressionCodec(metadata.GetCodec()).uncompress(dataBuf); err != nil {
return nil, err
}
if repLevelsLen == 0 && defLevelsLen == 0 {
return dataBuf, nil
}
if repLevelsLen > 0 {
data = append(data, uint32ToBytes(uint32(repLevelsLen))...)
data = append(data, repLevelsBuf...)
}
if defLevelsLen > 0 {
data = append(data, uint32ToBytes(uint32(defLevelsLen))...)
data = append(data, defLevelsBuf...)
}
data = append(data, dataBuf...)
return data, nil
}
buf, err := read()
if err != nil {
return nil, 0, 0, err
}
path := append([]string{}, metadata.GetPathInSchema()...)
bytesReader := bytes.NewReader(buf)
pageType := pageHeader.GetType()
switch pageType {
case parquet.PageType_INDEX_PAGE:
return nil, 0, 0, fmt.Errorf("page type %v is not supported", parquet.PageType_INDEX_PAGE)
case parquet.PageType_DICTIONARY_PAGE:
page = newDictPage()
page.Header = pageHeader
table := new(table)
table.Path = path
values, err := readValues(bytesReader, metadata.GetType(),
uint64(pageHeader.DictionaryPageHeader.GetNumValues()), 0)
if err != nil {
return nil, 0, 0, err
}
table.Values = getTableValues(values, metadata.GetType())
page.DataTable = table
return page, 0, 0, nil
case parquet.PageType_DATA_PAGE, parquet.PageType_DATA_PAGE_V2:
name := strings.Join(path, ".")
page = newDataPage()
page.Header = pageHeader
maxDefinitionLevel := getMaxDefLevel(columnNameIndexMap, schemaElements, path)
maxRepetitionLevel := getMaxRepLevel(columnNameIndexMap, schemaElements, path)
var numValues uint64
var encodingType parquet.Encoding
if pageHeader.GetType() == parquet.PageType_DATA_PAGE {
numValues = uint64(pageHeader.DataPageHeader.GetNumValues())
encodingType = pageHeader.DataPageHeader.GetEncoding()
} else {
numValues = uint64(pageHeader.DataPageHeaderV2.GetNumValues())
encodingType = pageHeader.DataPageHeaderV2.GetEncoding()
}
var repetitionLevels []int64
if maxRepetitionLevel > 0 {
values, _, err := readDataPageValues(bytesReader, parquet.Encoding_RLE, parquet.Type_INT64,
-1, numValues, getBitWidth(uint64(maxRepetitionLevel)))
if err != nil {
return nil, 0, 0, err
}
if repetitionLevels = values.([]int64); uint64(len(repetitionLevels)) > numValues {
repetitionLevels = repetitionLevels[:numValues]
}
} else {
repetitionLevels = make([]int64, numValues)
}
var definitionLevels []int64
if maxDefinitionLevel > 0 {
values, _, err := readDataPageValues(bytesReader, parquet.Encoding_RLE, parquet.Type_INT64,
-1, numValues, getBitWidth(uint64(maxDefinitionLevel)))
if err != nil {
return nil, 0, 0, err
}
if definitionLevels = values.([]int64); uint64(len(definitionLevels)) > numValues {
definitionLevels = definitionLevels[:numValues]
}
} else {
definitionLevels = make([]int64, numValues)
}
var numNulls uint64
for i := 0; i < len(definitionLevels); i++ {
if definitionLevels[i] != int64(maxDefinitionLevel) {
numNulls++
}
}
var convertedType parquet.ConvertedType = -1
if schemaElements[columnNameIndexMap[name]].IsSetConvertedType() {
convertedType = schemaElements[columnNameIndexMap[name]].GetConvertedType()
}
values, valueType, err := readDataPageValues(bytesReader, encodingType, metadata.GetType(),
convertedType, uint64(len(definitionLevels))-numNulls,
uint64(schemaElements[columnNameIndexMap[name]].GetTypeLength()))
if err != nil {
return nil, 0, 0, err
}
tableValues := getTableValues(values, valueType)
table := new(table)
table.Path = path
table.RepetitionType = schemaElements[columnNameIndexMap[name]].GetRepetitionType()
table.MaxRepetitionLevel = int32(maxRepetitionLevel)
table.MaxDefinitionLevel = int32(maxDefinitionLevel)
table.Values = make([]interface{}, len(definitionLevels))
table.RepetitionLevels = make([]int32, len(definitionLevels))
table.DefinitionLevels = make([]int32, len(definitionLevels))
j := 0
numRows := int64(0)
for i := 0; i < len(definitionLevels); i++ {
table.RepetitionLevels[i] = int32(repetitionLevels[i])
table.DefinitionLevels[i] = int32(definitionLevels[i])
if int(table.DefinitionLevels[i]) == maxDefinitionLevel {
table.Values[i] = tableValues[j]
j++
}
if table.RepetitionLevels[i] == 0 {
numRows++
}
}
page.DataTable = table
return page, int64(len(definitionLevels)), numRows, nil
}
return nil, 0, 0, fmt.Errorf("unknown page type %v", pageType)
}
type page struct {
Header *parquet.PageHeader // Header of a page
DataTable *table // Table to store values
RawData []byte // Compressed data of the page, which is written in parquet file
CompressType parquet.CompressionCodec // Compress type: gzip/snappy/none
DataType parquet.Type // Parquet type of the values in the page
Path []string // Path in schema(include the root)
MaxVal interface{} // Maximum of the values
MinVal interface{} // Minimum of the values
PageSize int32
}
func newPage() *page {
return &page{
Header: parquet.NewPageHeader(),
PageSize: 8 * 1024,
}
}
func newDictPage() *page {
page := newPage()
page.Header.DictionaryPageHeader = parquet.NewDictionaryPageHeader()
return page
}
func newDataPage() *page {
page := newPage()
page.Header.DataPageHeader = parquet.NewDataPageHeader()
return page
}
func (page *page) decode(dictPage *page) {
if dictPage == nil || page == nil || page.Header.DataPageHeader == nil ||
(page.Header.DataPageHeader.Encoding != parquet.Encoding_RLE_DICTIONARY &&
page.Header.DataPageHeader.Encoding != parquet.Encoding_PLAIN_DICTIONARY) {
return
}
for i := 0; i < len(page.DataTable.Values); i++ {
if page.DataTable.Values[i] != nil {
index := page.DataTable.Values[i].(int64)
page.DataTable.Values[i] = dictPage.DataTable.Values[index]
}
}
}
// Get RepetitionLevels and DefinitionLevels from RawData
func (page *page) getRLDLFromRawData(columnNameIndexMap map[string]int, schemaElements []*parquet.SchemaElement) (numValues int64, numRows int64, err error) {
bytesReader := bytes.NewReader(page.RawData)
pageType := page.Header.GetType()
var buf []byte
if pageType == parquet.PageType_DATA_PAGE_V2 {
var repLevelsLen, defLevelsLen int32
var repLevelsBuf, defLevelsBuf []byte
repLevelsLen = page.Header.DataPageHeaderV2.GetRepetitionLevelsByteLength()
repLevelsBuf = make([]byte, repLevelsLen)
if _, err = bytesReader.Read(repLevelsBuf); err != nil {
return 0, 0, err
}
defLevelsLen = page.Header.DataPageHeaderV2.GetDefinitionLevelsByteLength()
defLevelsBuf = make([]byte, defLevelsLen)
if _, err = bytesReader.Read(defLevelsBuf); err != nil {
return 0, 0, err
}
dataBuf := make([]byte, len(page.RawData)-int(repLevelsLen)-int(defLevelsLen))
if _, err = bytesReader.Read(dataBuf); err != nil {
return 0, 0, err
}
if repLevelsLen == 0 && defLevelsLen == 0 {
buf = dataBuf
} else {
if repLevelsLen > 0 {
buf = append(buf, uint32ToBytes(uint32(repLevelsLen))...)
buf = append(buf, repLevelsBuf...)
}
if defLevelsLen > 0 {
buf = append(buf, uint32ToBytes(uint32(defLevelsLen))...)
buf = append(buf, defLevelsBuf...)
}
buf = append(buf, dataBuf...)
}
} else {
if buf, err = compressionCodec(page.CompressType).uncompress(page.RawData); err != nil {
return 0, 0, err
}
}
bytesReader = bytes.NewReader(buf)
switch pageType {
case parquet.PageType_DICTIONARY_PAGE:
table := new(table)
table.Path = page.Path
page.DataTable = table
return 0, 0, nil
case parquet.PageType_DATA_PAGE, parquet.PageType_DATA_PAGE_V2:
var numValues uint64
if pageType == parquet.PageType_DATA_PAGE {
numValues = uint64(page.Header.DataPageHeader.GetNumValues())
} else {
numValues = uint64(page.Header.DataPageHeaderV2.GetNumValues())
}
maxDefinitionLevel := getMaxDefLevel(columnNameIndexMap, schemaElements, page.Path)
maxRepetitionLevel := getMaxRepLevel(columnNameIndexMap, schemaElements, page.Path)
var repetitionLevels []int64
if maxRepetitionLevel > 0 {
values, _, err := readDataPageValues(bytesReader, parquet.Encoding_RLE, parquet.Type_INT64,
-1, numValues, getBitWidth(uint64(maxRepetitionLevel)))
if err != nil {
return 0, 0, err
}
if repetitionLevels = values.([]int64); uint64(len(repetitionLevels)) > numValues {
repetitionLevels = repetitionLevels[:numValues]
}
} else {
repetitionLevels = make([]int64, numValues)
}
var definitionLevels []int64
if maxDefinitionLevel > 0 {
values, _, err := readDataPageValues(bytesReader, parquet.Encoding_RLE, parquet.Type_INT64,
-1, numValues, getBitWidth(uint64(maxDefinitionLevel)))
if err != nil {
return 0, 0, err
}
if definitionLevels = values.([]int64); uint64(len(definitionLevels)) > numValues {
definitionLevels = definitionLevels[:numValues]
}
} else {
definitionLevels = make([]int64, numValues)
}
table := new(table)
table.Path = page.Path
name := strings.Join(page.Path, ".")
table.RepetitionType = schemaElements[columnNameIndexMap[name]].GetRepetitionType()
table.MaxRepetitionLevel = int32(maxRepetitionLevel)
table.MaxDefinitionLevel = int32(maxDefinitionLevel)
table.Values = make([]interface{}, len(definitionLevels))
table.RepetitionLevels = make([]int32, len(definitionLevels))
table.DefinitionLevels = make([]int32, len(definitionLevels))
numRows := int64(0)
for i := 0; i < len(definitionLevels); i++ {
table.RepetitionLevels[i] = int32(repetitionLevels[i])
table.DefinitionLevels[i] = int32(definitionLevels[i])
if table.RepetitionLevels[i] == 0 {
numRows++
}
}
page.DataTable = table
page.RawData = buf[len(buf)-bytesReader.Len():]
return int64(numValues), numRows, nil
}
return 0, 0, fmt.Errorf("unsupported page type %v", pageType)
}
func (page *page) getValueFromRawData(columnNameIndexMap map[string]int, schemaElements []*parquet.SchemaElement) (err error) {
pageType := page.Header.GetType()
switch pageType {
case parquet.PageType_DICTIONARY_PAGE:
bytesReader := bytes.NewReader(page.RawData)
var values interface{}
values, err = readValues(bytesReader, page.DataType,
uint64(page.Header.DictionaryPageHeader.GetNumValues()), 0)
if err != nil {
return err
}
page.DataTable.Values = getTableValues(values, page.DataType)
return nil
case parquet.PageType_DATA_PAGE_V2:
if page.RawData, err = compressionCodec(page.CompressType).uncompress(page.RawData); err != nil {
return err
}
fallthrough
case parquet.PageType_DATA_PAGE:
encodingType := page.Header.DataPageHeader.GetEncoding()
bytesReader := bytes.NewReader(page.RawData)
var numNulls uint64
for i := 0; i < len(page.DataTable.DefinitionLevels); i++ {
if page.DataTable.DefinitionLevels[i] != page.DataTable.MaxDefinitionLevel {
numNulls++
}
}
name := strings.Join(page.DataTable.Path, ".")
var convertedType parquet.ConvertedType = -1
if schemaElements[columnNameIndexMap[name]].IsSetConvertedType() {
convertedType = schemaElements[columnNameIndexMap[name]].GetConvertedType()
}
values, _, err := readDataPageValues(bytesReader, encodingType, page.DataType,
convertedType, uint64(len(page.DataTable.DefinitionLevels))-numNulls,
uint64(schemaElements[columnNameIndexMap[name]].GetTypeLength()))
if err != nil {
return err
}
tableValues := getTableValues(values, page.DataType)
j := 0
for i := 0; i < len(page.DataTable.DefinitionLevels); i++ {
if page.DataTable.DefinitionLevels[i] == page.DataTable.MaxDefinitionLevel {
page.DataTable.Values[i] = tableValues[j]
j++
}
}
page.RawData = []byte{}
return nil
}
return fmt.Errorf("unsupported page type %v", pageType)
}

162
vendor/github.com/minio/parquet-go/parquet.go generated vendored Normal file
View File

@ -0,0 +1,162 @@
/*
* Minio Cloud Storage, (C) 2018 Minio, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package parquet
import (
"encoding/binary"
"encoding/json"
"io"
"git.apache.org/thrift.git/lib/go/thrift"
"github.com/minio/minio-go/pkg/set"
"github.com/minio/parquet-go/gen-go/parquet"
)
// GetReaderFunc - function type returning io.ReadCloser for requested offset/length.
type GetReaderFunc func(offset, length int64) (io.ReadCloser, error)
func footerSize(getReaderFunc GetReaderFunc) (size int64, err error) {
rc, err := getReaderFunc(-8, 4)
if err != nil {
return 0, err
}
defer rc.Close()
buf := make([]byte, 4)
if _, err = io.ReadFull(rc, buf); err != nil {
return 0, err
}
size = int64(binary.LittleEndian.Uint32(buf))
return size, nil
}
func fileMetadata(getReaderFunc GetReaderFunc) (*parquet.FileMetaData, error) {
size, err := footerSize(getReaderFunc)
if err != nil {
return nil, err
}
rc, err := getReaderFunc(-(8 + size), size)
if err != nil {
return nil, err
}
defer rc.Close()
fileMeta := parquet.NewFileMetaData()
pf := thrift.NewTCompactProtocolFactory()
protocol := pf.GetProtocol(thrift.NewStreamTransportR(rc))
err = fileMeta.Read(protocol)
if err != nil {
return nil, err
}
return fileMeta, nil
}
// Value - denotes column value
type Value struct {
Value interface{}
Type parquet.Type
}
// MarshalJSON - encodes to JSON data
func (value Value) MarshalJSON() (data []byte, err error) {
return json.Marshal(value.Value)
}
// File - denotes parquet file.
type File struct {
getReaderFunc GetReaderFunc
schemaElements []*parquet.SchemaElement
rowGroups []*parquet.RowGroup
rowGroupIndex int
columnNames set.StringSet
columns map[string]*column
rowIndex int64
}
// Open - opens parquet file with given column names.
func Open(getReaderFunc GetReaderFunc, columnNames set.StringSet) (*File, error) {
fileMeta, err := fileMetadata(getReaderFunc)
if err != nil {
return nil, err
}
return &File{
getReaderFunc: getReaderFunc,
rowGroups: fileMeta.GetRowGroups(),
schemaElements: fileMeta.GetSchema(),
columnNames: columnNames,
}, nil
}
// Read - reads single record.
func (file *File) Read() (record map[string]Value, err error) {
if file.rowGroupIndex >= len(file.rowGroups) {
return nil, io.EOF
}
if file.columns == nil {
file.columns, err = getColumns(
file.rowGroups[file.rowGroupIndex],
file.columnNames,
file.schemaElements,
file.getReaderFunc,
)
if err != nil {
return nil, err
}
file.rowIndex = 0
}
if file.rowIndex >= file.rowGroups[file.rowGroupIndex].GetNumRows() {
file.rowGroupIndex++
file.Close()
return file.Read()
}
record = make(map[string]Value)
for name := range file.columns {
value, valueType := file.columns[name].read()
record[name] = Value{value, valueType}
}
file.rowIndex++
return record, nil
}
// Close - closes underneath readers.
func (file *File) Close() (err error) {
// Nothing to close when no columns are open.
if file.columns == nil {
return nil
}
for _, column := range file.columns {
column.close()
}
file.columns = nil
file.rowIndex = 0
return nil
}
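An end-to-end sketch of the reader API, not taken from the repository: footerSize requests negative offsets, so a GetReaderFunc implementation must treat offset < 0 as relative to end-of-file. The file and column names below are placeholders.
package main

import (
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"os"

	"github.com/minio/minio-go/pkg/set"
	parquet "github.com/minio/parquet-go"
)

func main() {
	f, err := os.Open("example.parquet") // placeholder file name
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	getReader := func(offset, length int64) (io.ReadCloser, error) {
		if offset < 0 { // negative offsets are relative to end-of-file
			st, err := f.Stat()
			if err != nil {
				return nil, err
			}
			offset += st.Size()
		}
		return ioutil.NopCloser(io.NewSectionReader(f, offset, length)), nil
	}

	file, err := parquet.Open(getReader, set.CreateStringSet("name")) // placeholder column
	if err != nil {
		log.Fatal(err)
	}
	defer file.Close()

	for {
		record, err := file.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(record)
	}
}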

881
vendor/github.com/minio/parquet-go/parquet.thrift generated vendored Normal file
View File

@ -0,0 +1,881 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* File format description for the parquet file format
*/
namespace cpp parquet
namespace java org.apache.parquet.format
/**
* Types supported by Parquet. These types are intended to be used in combination
* with the encodings to control the on disk storage format.
* For example INT16 is not included as a type since a good encoding of INT32
* would handle this.
*/
enum Type {
BOOLEAN = 0;
INT32 = 1;
INT64 = 2;
INT96 = 3; // deprecated, only used by legacy implementations.
FLOAT = 4;
DOUBLE = 5;
BYTE_ARRAY = 6;
FIXED_LEN_BYTE_ARRAY = 7;
}
/**
* Common types used by frameworks (e.g. hive, pig) using parquet. This helps map
* between types in those frameworks to the base types in parquet. This is only
* metadata and not needed to read or write the data.
*/
enum ConvertedType {
/** a BYTE_ARRAY actually contains UTF8 encoded chars */
UTF8 = 0;
/** a map is converted as an optional field containing a repeated key/value pair */
MAP = 1;
/** a key/value pair is converted into a group of two fields */
MAP_KEY_VALUE = 2;
/** a list is converted into an optional field containing a repeated field for its
* values */
LIST = 3;
/** an enum is converted into a binary field */
ENUM = 4;
/**
* A decimal value.
*
* This may be used to annotate binary or fixed primitive types. The
* underlying byte array stores the unscaled value encoded as two's
* complement using big-endian byte order (the most significant byte is the
* zeroth element). The value of the decimal is the value * 10^{-scale}.
*
* This must be accompanied by a (maximum) precision and a scale in the
* SchemaElement. The precision specifies the number of digits in the decimal
* and the scale stores the location of the decimal point. For example 1.23
* would have precision 3 (3 total digits) and scale 2 (the decimal point is
* 2 digits over).
*/
DECIMAL = 5;
/**
* A Date
*
* Stored as days since Unix epoch, encoded as the INT32 physical type.
*
*/
DATE = 6;
/**
* A time
*
* The total number of milliseconds since midnight. The value is stored
* as an INT32 physical type.
*/
TIME_MILLIS = 7;
/**
* A time.
*
* The total number of microseconds since midnight. The value is stored as
* an INT64 physical type.
*/
TIME_MICROS = 8;
/**
* A date/time combination
*
* Date and time recorded as milliseconds since the Unix epoch. Recorded as
* a physical type of INT64.
*/
TIMESTAMP_MILLIS = 9;
/**
* A date/time combination
*
* Date and time recorded as microseconds since the Unix epoch. The value is
* stored as an INT64 physical type.
*/
TIMESTAMP_MICROS = 10;
/**
* An unsigned integer value.
*
* The number describes the maximum number of meaningful data bits in
* the stored value. 8, 16 and 32 bit values are stored using the
* INT32 physical type. 64 bit values are stored using the INT64
* physical type.
*
*/
UINT_8 = 11;
UINT_16 = 12;
UINT_32 = 13;
UINT_64 = 14;
/**
* A signed integer value.
*
* The number describes the maximum number of meaningful data bits in
* the stored value. 8, 16 and 32 bit values are stored using the
* INT32 physical type. 64 bit values are stored using the INT64
* physical type.
*
*/
INT_8 = 15;
INT_16 = 16;
INT_32 = 17;
INT_64 = 18;
/**
* An embedded JSON document
*
* A JSON document embedded within a single UTF8 column.
*/
JSON = 19;
/**
* An embedded BSON document
*
* A BSON document embedded within a single BINARY column.
*/
BSON = 20;
/**
* An interval of time
*
* This type annotates data stored as a FIXED_LEN_BYTE_ARRAY of length 12
* This data is composed of three separate little endian unsigned
* integers. Each stores a component of a duration of time. The first
* integer identifies the number of months associated with the duration,
* the second identifies the number of days associated with the duration
* and the third identifies the number of milliseconds associated with
* the provided duration. This duration of time is independent of any
* particular timezone or date.
*/
INTERVAL = 21;
}
/**
* Representation of Schemas
*/
enum FieldRepetitionType {
/** This field is required (can not be null) and each record has exactly 1 value. */
REQUIRED = 0;
/** The field is optional (can be null) and each record has 0 or 1 values. */
OPTIONAL = 1;
/** The field is repeated and can contain 0 or more values */
REPEATED = 2;
}
/**
* Statistics per row group and per page
* All fields are optional.
*/
struct Statistics {
/**
* DEPRECATED: min and max value of the column. Use min_value and max_value.
*
* Values are encoded using PLAIN encoding, except that variable-length byte
* arrays do not include a length prefix.
*
* These fields encode min and max values determined by signed comparison
* only. New files should use the correct order for a column's logical type
* and store the values in the min_value and max_value fields.
*
* To support older readers, these may be set when the column order is
* signed.
*/
1: optional binary max;
2: optional binary min;
/** count of null value in the column */
3: optional i64 null_count;
/** count of distinct values occurring */
4: optional i64 distinct_count;
/**
* Min and max values for the column, determined by its ColumnOrder.
*
* Values are encoded using PLAIN encoding, except that variable-length byte
* arrays do not include a length prefix.
*/
5: optional binary max_value;
6: optional binary min_value;
}
/** Empty structs to use as logical type annotations */
struct StringType {} // allowed for BINARY, must be encoded with UTF-8
struct UUIDType {} // allowed for FIXED[16], must encode raw UUID bytes
struct MapType {} // see LogicalTypes.md
struct ListType {} // see LogicalTypes.md
struct EnumType {} // allowed for BINARY, must be encoded with UTF-8
struct DateType {} // allowed for INT32
/**
* Logical type to annotate a column that is always null.
*
* Sometimes when discovering the schema of existing data, values are always
* null and the physical type can't be determined. This annotation signals
* the case where the physical type was guessed from all null values.
*/
struct NullType {} // allowed for any physical type, only null values stored
/**
* Decimal logical type annotation
*
* To maintain forward-compatibility in v1, implementations using this logical
* type must also set scale and precision on the annotated SchemaElement.
*
* Allowed for physical types: INT32, INT64, FIXED, and BINARY
*/
struct DecimalType {
1: required i32 scale
2: required i32 precision
}
/** Time units for logical types */
struct MilliSeconds {}
struct MicroSeconds {}
struct NanoSeconds {}
union TimeUnit {
1: MilliSeconds MILLIS
2: MicroSeconds MICROS
3: NanoSeconds NANOS
}
/**
* Timestamp logical type annotation
*
* Allowed for physical types: INT64
*/
struct TimestampType {
1: required bool isAdjustedToUTC
2: required TimeUnit unit
}
/**
* Time logical type annotation
*
* Allowed for physical types: INT32 (millis), INT64 (micros, nanos)
*/
struct TimeType {
1: required bool isAdjustedToUTC
2: required TimeUnit unit
}
/**
* Integer logical type annotation
*
* bitWidth must be 8, 16, 32, or 64.
*
* Allowed for physical types: INT32, INT64
*/
struct IntType {
1: required byte bitWidth
2: required bool isSigned
}
/**
* Embedded JSON logical type annotation
*
* Allowed for physical types: BINARY
*/
struct JsonType {
}
/**
* Embedded BSON logical type annotation
*
* Allowed for physical types: BINARY
*/
struct BsonType {
}
/**
* LogicalType annotations to replace ConvertedType.
*
* To maintain compatibility, implementations using LogicalType for a
* SchemaElement must also set the corresponding ConvertedType from the
* following table.
*/
union LogicalType {
1: StringType STRING // use ConvertedType UTF8
2: MapType MAP // use ConvertedType MAP
3: ListType LIST // use ConvertedType LIST
4: EnumType ENUM // use ConvertedType ENUM
5: DecimalType DECIMAL // use ConvertedType DECIMAL
6: DateType DATE // use ConvertedType DATE
7: TimeType TIME // use ConvertedType TIME_MICROS or TIME_MILLIS
8: TimestampType TIMESTAMP // use ConvertedType TIMESTAMP_MICROS or TIMESTAMP_MILLIS
// 9: reserved for INTERVAL
10: IntType INTEGER // use ConvertedType INT_* or UINT_*
11: NullType UNKNOWN // no compatible ConvertedType
12: JsonType JSON // use ConvertedType JSON
13: BsonType BSON // use ConvertedType BSON
14: UUIDType UUID
}
/**
 * Represents an element inside a schema definition.
 * - if it is a group (inner node) then type is undefined and num_children is defined
 * - if it is a primitive type (leaf) then type is defined and num_children is undefined
 * The nodes are listed in depth-first traversal order.
*/
struct SchemaElement {
/** Data type for this field. Not set if the current element is a non-leaf node */
1: optional Type type;
/** If type is FIXED_LEN_BYTE_ARRAY, this is the byte length of the values.
* Otherwise, if specified, this is the maximum bit length to store any of the values.
* (e.g. a low cardinality INT col could have this set to 3). Note that this is
* in the schema, and therefore fixed for the entire file.
*/
2: optional i32 type_length;
/** repetition of the field. The root of the schema does not have a repetition_type.
* All other nodes must have one */
3: optional FieldRepetitionType repetition_type;
/** Name of the field in the schema */
4: required string name;
/** Nested fields. Since thrift does not support nested fields,
* the nesting is flattened to a single list by a depth-first traversal.
* The children count is used to construct the nested relationship.
* This field is not set when the element is a primitive type
*/
5: optional i32 num_children;
/** When the schema is the result of a conversion from another model.
* Used to record the original type to help with cross conversion.
*/
6: optional ConvertedType converted_type;
/** Used when this column contains decimal data.
* See the DECIMAL converted type for more details.
*/
7: optional i32 scale
8: optional i32 precision
/** When the original schema supports field ids, this will save the
* original field id in the parquet schema
*/
9: optional i32 field_id;
/**
* The logical type of this SchemaElement
*
* LogicalType replaces ConvertedType, but ConvertedType is still required
* for some logical types to ensure forward-compatibility in format v1.
*/
10: optional LogicalType logicalType
}
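// Illustrative (non-normative) example of the flattening described above.
// The nested schema
//   root { a: int32, b: group { c: string } }
// is stored as the depth-first list
//   { name: "root", num_children: 2 }
//   { name: "a", type: INT32 }
//   { name: "b", num_children: 1 }
//   { name: "c", type: BYTE_ARRAY, logicalType: STRING, converted_type: UTF8 }
// Note that "c" carries both logicalType and the matching converted_type, as
// required for format v1 compatibility.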
/**
* Encodings supported by Parquet. Not all encodings are valid for all types. These
* enums are also used to specify the encoding of definition and repetition levels.
* See the accompanying doc for the details of the more complicated encodings.
*/
enum Encoding {
/** Default encoding.
* BOOLEAN - 1 bit per value. 0 is false; 1 is true.
* INT32 - 4 bytes per value. Stored as little-endian.
* INT64 - 8 bytes per value. Stored as little-endian.
* FLOAT - 4 bytes per value. IEEE. Stored as little-endian.
* DOUBLE - 8 bytes per value. IEEE. Stored as little-endian.
* BYTE_ARRAY - 4 byte length stored as little endian, followed by bytes.
* FIXED_LEN_BYTE_ARRAY - Just the bytes.
*/
PLAIN = 0;
/** Group VarInt encoding for INT32/INT64.
* This encoding is deprecated. It was never used
*/
// GROUP_VAR_INT = 1;
/**
* Deprecated: Dictionary encoding. The values in the dictionary are encoded in the
* plain type.
 * In a data page use RLE_DICTIONARY instead.
 * In a dictionary page use PLAIN instead.
*/
PLAIN_DICTIONARY = 2;
/** Group packed run length encoding. Usable for definition/repetition levels
* encoding and Booleans (on one bit: 0 is false; 1 is true.)
*/
RLE = 3;
/** Bit packed encoding. This can only be used if the data has a known max
* width. Usable for definition/repetition levels encoding.
*/
BIT_PACKED = 4;
/** Delta encoding for integers. This can be used for int columns and works best
* on sorted data
*/
DELTA_BINARY_PACKED = 5;
/** Encoding for byte arrays to separate the length values and the data. The lengths
* are encoded using DELTA_BINARY_PACKED
*/
DELTA_LENGTH_BYTE_ARRAY = 6;
/** Incremental-encoded byte array. Prefix lengths are encoded using DELTA_BINARY_PACKED.
* Suffixes are stored as delta length byte arrays.
*/
DELTA_BYTE_ARRAY = 7;
/** Dictionary encoding: the ids are encoded using the RLE encoding
*/
RLE_DICTIONARY = 8;
}
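// Illustrative (non-normative) example of the RLE encoding above, as used for
// definition/repetition levels: with a max level of 1 (bit width 1), the
// levels [1, 1, 1, 1, 0] can be written as two RLE runs: header 0x08
// (= 4 << 1, an RLE run of length 4) followed by the value byte 0x01, then
// header 0x02 (= 1 << 1, a run of length 1) followed by the value byte 0x00,
// i.e. the raw run bytes 08 01 02 00. A writer may equally choose a bit-packed
// run; this is just one valid encoding.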
/**
* Supported compression algorithms.
*
* Codecs added in 2.4 can be read by readers based on 2.4 and later.
* Codec support may vary between readers based on the format version and
* libraries available at runtime. Gzip, Snappy, and LZ4 codecs are
* widely available, while Zstd and Brotli require additional libraries.
*/
enum CompressionCodec {
UNCOMPRESSED = 0;
SNAPPY = 1;
GZIP = 2;
LZO = 3;
BROTLI = 4; // Added in 2.4
LZ4 = 5; // Added in 2.4
ZSTD = 6; // Added in 2.4
}
enum PageType {
DATA_PAGE = 0;
INDEX_PAGE = 1;
DICTIONARY_PAGE = 2;
DATA_PAGE_V2 = 3;
}
/**
* Enum to annotate whether lists of min/max elements inside ColumnIndex
* are ordered and if so, in which direction.
*/
enum BoundaryOrder {
UNORDERED = 0;
ASCENDING = 1;
DESCENDING = 2;
}
/** Data page header */
struct DataPageHeader {
/** Number of values, including NULLs, in this data page. **/
1: required i32 num_values
/** Encoding used for this data page **/
2: required Encoding encoding
/** Encoding used for definition levels **/
3: required Encoding definition_level_encoding;
/** Encoding used for repetition levels **/
4: required Encoding repetition_level_encoding;
/** Optional statistics for the data in this page**/
5: optional Statistics statistics;
}
struct IndexPageHeader {
/** TODO: **/
}
struct DictionaryPageHeader {
/** Number of values in the dictionary **/
1: required i32 num_values;
/** Encoding used by this dictionary page **/
2: required Encoding encoding
/** If true, the entries in the dictionary are sorted in ascending order **/
3: optional bool is_sorted;
}
/**
* New page format allowing reading levels without decompressing the data
* Repetition and definition levels are uncompressed
* The remaining section containing the data is compressed if is_compressed is true
**/
struct DataPageHeaderV2 {
/** Number of values, including NULLs, in this data page. **/
1: required i32 num_values
/** Number of NULL values in this data page.
Number of non-null values = num_values - num_nulls, which is also the number of values in the data section **/
2: required i32 num_nulls
/** Number of rows in this data page, which means pages change on record boundaries (r = 0) **/
3: required i32 num_rows
/** Encoding used for data in this page **/
4: required Encoding encoding
// repetition levels and definition levels always use RLE (without the 4-byte length prefix used in v1 pages)
/** length of the definition levels */
5: required i32 definition_levels_byte_length;
/** length of the repetition levels */
6: required i32 repetition_levels_byte_length;
/** Whether the values are compressed, which means the section of the page between
definition_levels_byte_length + repetition_levels_byte_length + 1 and compressed_page_size (included)
is compressed with the compression_codec.
If missing, it is considered compressed */
7: optional bool is_compressed = 1;
/** optional statistics for this column chunk */
8: optional Statistics statistics;
}
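// Illustrative layout of the page body that follows this header (a sketch,
// not normative): repetition levels come first, then definition levels, then
// the values:
//   [ repetition levels : repetition_levels_byte_length bytes, RLE, never compressed ]
//   [ definition levels : definition_levels_byte_length bytes, RLE, never compressed ]
//   [ encoded values    : compressed with the compression_codec iff is_compressed ]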
struct PageHeader {
/** the type of the page: indicates which of the *_header fields is set **/
1: required PageType type
/** Uncompressed page size in bytes (not including this header) **/
2: required i32 uncompressed_page_size
/** Compressed page size in bytes (not including this header) **/
3: required i32 compressed_page_size
/** 32-bit CRC for the data below. This allows for disabling checksumming in HDFS
 * if only a few pages need to be read
**/
4: optional i32 crc
// Headers for page specific data. One only will be set.
5: optional DataPageHeader data_page_header;
6: optional IndexPageHeader index_page_header;
7: optional DictionaryPageHeader dictionary_page_header;
8: optional DataPageHeaderV2 data_page_header_v2;
}
/**
* Wrapper struct to store key values
*/
struct KeyValue {
1: required string key
2: optional string value
}
/**
* Wrapper struct to specify sort order
*/
struct SortingColumn {
/** The column index (in this row group) **/
1: required i32 column_idx
/** If true, indicates this column is sorted in descending order. **/
2: required bool descending
/** If true, nulls will come before non-null values, otherwise,
* nulls go at the end. */
3: required bool nulls_first
}
/**
* statistics of a given page type and encoding
*/
struct PageEncodingStats {
/** the page type (data/dic/...) **/
1: required PageType page_type;
/** encoding of the page **/
2: required Encoding encoding;
/** number of pages of this type with this encoding **/
3: required i32 count;
}
/**
* Description for column metadata
*/
struct ColumnMetaData {
/** Type of this column **/
1: required Type type
/** Set of all encodings used for this column. The purpose is to validate
* whether we can decode those pages. **/
2: required list<Encoding> encodings
/** Path in schema **/
3: required list<string> path_in_schema
/** Compression codec **/
4: required CompressionCodec codec
/** Number of values in this column **/
5: required i64 num_values
/** total byte size of all uncompressed pages in this column chunk (including the headers) **/
6: required i64 total_uncompressed_size
/** total byte size of all compressed pages in this column chunk (including the headers) **/
7: required i64 total_compressed_size
/** Optional key/value metadata **/
8: optional list<KeyValue> key_value_metadata
/** Byte offset from beginning of file to first data page **/
9: required i64 data_page_offset
/** Byte offset from beginning of file to root index page **/
10: optional i64 index_page_offset
/** Byte offset from the beginning of file to first (only) dictionary page **/
11: optional i64 dictionary_page_offset
/** optional statistics for this column chunk */
12: optional Statistics statistics;
/** Set of all encodings used for pages in this column chunk.
* This information can be used to determine if all data pages are
 * dictionary encoded, for example **/
13: optional list<PageEncodingStats> encoding_stats;
}
struct ColumnChunk {
/** File where column data is stored. If not set, assumed to be same file as
* metadata. This path is relative to the current file.
**/
1: optional string file_path
/** Byte offset in file_path to the ColumnMetaData **/
2: required i64 file_offset
/** Column metadata for this chunk. This is the same content as what is at
 * file_path/file_offset. Having it here means it is replicated in the file
* metadata.
**/
3: optional ColumnMetaData meta_data
/** File offset of ColumnChunk's OffsetIndex **/
4: optional i64 offset_index_offset
/** Size of ColumnChunk's OffsetIndex, in bytes **/
5: optional i32 offset_index_length
/** File offset of ColumnChunk's ColumnIndex **/
6: optional i64 column_index_offset
/** Size of ColumnChunk's ColumnIndex, in bytes **/
7: optional i32 column_index_length
}
struct RowGroup {
/** Metadata for each column chunk in this row group.
* This list must have the same order as the SchemaElement list in FileMetaData.
**/
1: required list<ColumnChunk> columns
/** Total byte size of all the uncompressed column data in this row group **/
2: required i64 total_byte_size
/** Number of rows in this row group **/
3: required i64 num_rows
/** If set, specifies a sort ordering of the rows in this RowGroup.
* The sorting columns can be a subset of all the columns.
*/
4: optional list<SortingColumn> sorting_columns
}
/** Empty struct to signal the order defined by the physical or logical type */
struct TypeDefinedOrder {}
/**
* Union to specify the order used for the min_value and max_value fields for a
* column. This union takes the role of an enhanced enum that allows rich
* elements (which will be needed for a collation-based ordering in the future).
*
* Possible values are:
* * TypeDefinedOrder - the column uses the order defined by its logical or
* physical type (if there is no logical type).
*
* If the reader does not support the value of this union, min and max stats
* for this column should be ignored.
*/
union ColumnOrder {
/**
* The sort orders for logical types are:
* UTF8 - unsigned byte-wise comparison
* INT8 - signed comparison
* INT16 - signed comparison
* INT32 - signed comparison
* INT64 - signed comparison
* UINT8 - unsigned comparison
* UINT16 - unsigned comparison
* UINT32 - unsigned comparison
* UINT64 - unsigned comparison
* DECIMAL - signed comparison of the represented value
* DATE - signed comparison
* TIME_MILLIS - signed comparison
* TIME_MICROS - signed comparison
* TIMESTAMP_MILLIS - signed comparison
* TIMESTAMP_MICROS - signed comparison
* INTERVAL - unsigned comparison
* JSON - unsigned byte-wise comparison
* BSON - unsigned byte-wise comparison
* ENUM - unsigned byte-wise comparison
* LIST - undefined
* MAP - undefined
*
* In the absence of logical types, the sort order is determined by the physical type:
* BOOLEAN - false, true
* INT32 - signed comparison
* INT64 - signed comparison
* INT96 (only used for legacy timestamps) - undefined
* FLOAT - signed comparison of the represented value (*)
* DOUBLE - signed comparison of the represented value (*)
* BYTE_ARRAY - unsigned byte-wise comparison
* FIXED_LEN_BYTE_ARRAY - unsigned byte-wise comparison
*
* (*) Because the sorting order is not specified properly for floating
* point values (relations vs. total ordering) the following
* compatibility rules should be applied when reading statistics:
* - If the min is a NaN, it should be ignored.
* - If the max is a NaN, it should be ignored.
* - If the min is +0, the row group may contain -0 values as well.
* - If the max is -0, the row group may contain +0 values as well.
* - When looking for NaN values, min and max should be ignored.
*/
1: TypeDefinedOrder TYPE_ORDER;
}
struct PageLocation {
/** Offset of the page in the file **/
1: required i64 offset
/**
* Size of the page, including header. Sum of compressed_page_size and header
* length
*/
2: required i32 compressed_page_size
/**
* Index within the RowGroup of the first row of the page; this means pages
* change on record boundaries (r = 0).
*/
3: required i64 first_row_index
}
struct OffsetIndex {
/**
* PageLocations, ordered by increasing PageLocation.offset. It is required
* that page_locations[i].first_row_index < page_locations[i+1].first_row_index.
*/
1: required list<PageLocation> page_locations
}
/**
* Description for ColumnIndex.
* Each <array-field>[i] refers to the page at OffsetIndex.page_locations[i]
*/
struct ColumnIndex {
/**
* A list of Boolean values to determine the validity of the corresponding
* min and max values. If true, a page contains only null values, and writers
* have to set the corresponding entries in min_values and max_values to
* byte[0], so that all lists have the same length. If false, the
* corresponding entries in min_values and max_values must be valid.
*/
1: required list<bool> null_pages
/**
* Two lists containing lower and upper bounds for the values of each page.
* These may be the actual minimum and maximum values found on a page, but
* can also be (more compact) values that do not exist on a page. For
 * example, instead of storing "Blart Versenwald III", a writer may set
* min_values[i]="B", max_values[i]="C". Such more compact values must still
* be valid values within the column's logical type. Readers must make sure
* that list entries are populated before using them by inspecting null_pages.
*/
2: required list<binary> min_values
3: required list<binary> max_values
/**
 * Stores whether both min_values and max_values are ordered and if so, in
* which direction. This allows readers to perform binary searches in both
* lists. Readers cannot assume that max_values[i] <= min_values[i+1], even
* if the lists are ordered.
*/
4: required BoundaryOrder boundary_order
/** A list containing the number of null values for each page **/
5: optional list<i64> null_counts
}
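// Illustrative (non-normative) example for a three-page BYTE_ARRAY column:
//   null_pages     = [false, true, false]
//   min_values     = ["B",   "",   "a"]
//   max_values     = ["C",   "",   "b"]
//   boundary_order = UNORDERED   // always a valid choice for a writer
//   null_counts    = [0, 100, 4]
// The second page contains only nulls, so its min/max entries are empty
// placeholders; readers must consult null_pages before using them.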
/**
* Description for file metadata
*/
struct FileMetaData {
/** Version of this file **/
1: required i32 version
/** Parquet schema for this file. This schema contains metadata for all the columns.
* The schema is represented as a tree with a single root. The nodes of the tree
* are flattened to a list by doing a depth-first traversal.
* The column metadata contains the path in the schema for that column which can be
* used to map columns to nodes in the schema.
* The first element is the root **/
2: required list<SchemaElement> schema;
/** Number of rows in this file **/
3: required i64 num_rows
/** Row groups in this file **/
4: required list<RowGroup> row_groups
/** Optional key/value metadata **/
5: optional list<KeyValue> key_value_metadata
/** String for application that wrote this file. This should be in the format
* <Application> version <App Version> (build <App Build Hash>).
* e.g. impala version 1.0 (build 6cf94d29b2b7115df4de2c06e2ab4326d721eb55)
**/
6: optional string created_by
/**
* Sort order used for the min_value and max_value fields of each column in
* this file. Each sort order corresponds to one column, determined by its
* position in the list, matching the position of the column in the schema.
*
* Without column_orders, the meaning of the min_value and max_value fields is
* undefined. To ensure well-defined behaviour, if min_value and max_value are
* written to a Parquet file, column_orders must be written as well.
*
* The obsolete min and max fields are always sorted by signed comparison
* regardless of column_orders.
*/
7: optional list<ColumnOrder> column_orders;
}
