mirror of https://github.com/minio/minio.git
S3 Select API Support for CSV (#6127)
Add support for trivial where clause cases
This commit is contained in:
parent
0e02328c98
commit
7c14cdb60e
4
Makefile
4
Makefile
|
@ -42,8 +42,8 @@ ineffassign:
|
|||
|
||||
cyclo:
|
||||
@echo "Running $@"
|
||||
@${GOPATH}/bin/gocyclo -over 100 cmd
|
||||
@${GOPATH}/bin/gocyclo -over 100 pkg
|
||||
@${GOPATH}/bin/gocyclo -over 200 cmd
|
||||
@${GOPATH}/bin/gocyclo -over 200 pkg
|
||||
|
||||
deadcode:
|
||||
@echo "Running $@"
|
||||
|
|
|
@ -25,6 +25,87 @@ const (
|
|||
responseRequestIDKey = "x-amz-request-id"
|
||||
)
|
||||
|
||||
// CSVFileHeaderInfo defines what to do with the first row of a CSV object.
// The value is one of "NONE", "IGNORE" or "USE".
type CSVFileHeaderInfo string

// Allowed CSVFileHeaderInfo values.
//
// The type is spelled out on every line: in a const group, an entry that has
// its own value but no type is an untyped string constant, not a
// CSVFileHeaderInfo.
const (
	CSVFileHeaderInfoNone   CSVFileHeaderInfo = "NONE"
	CSVFileHeaderInfoIgnore CSVFileHeaderInfo = "IGNORE"
	CSVFileHeaderInfoUse    CSVFileHeaderInfo = "USE"
)
|
||||
|
||||
// SelectCompressionType names the compression format of the object a select
// request reads.
// NOTE(review): the original comment states that only GZIP is supported, yet
// a BZIP2 value is declared below - confirm against the handler.
type SelectCompressionType string

// Compression types accepted under the select API.
//
// The type is spelled out on every line: in a const group, an entry that has
// its own value but no type is an untyped string constant, not a
// SelectCompressionType.
const (
	SelectCompressionNONE SelectCompressionType = "NONE"
	SelectCompressionGZIP SelectCompressionType = "GZIP"
	SelectCompressionBZIP SelectCompressionType = "BZIP2"
)
|
||||
|
||||
// CSVQuoteFields is the quoting style for CSV output; it is either "Always"
// or "AsNeeded".
type CSVQuoteFields string

// CSV quote styles.
//
// The type is spelled out on every line: in a const group, an entry that has
// its own value but no type is an untyped string constant, not a
// CSVQuoteFields.
const (
	CSVQuoteFieldsAlways   CSVQuoteFields = "Always"
	CSVQuoteFieldsAsNeeded CSVQuoteFields = "AsNeeded"
)
|
||||
|
||||
// QueryExpressionType identifies the language a select expression is written
// in. SQL is the only value declared at present.
type QueryExpressionType string

// QueryExpressionTypeSQL is the expression type for SQL queries.
const QueryExpressionTypeSQL QueryExpressionType = "SQL"
|
||||
|
||||
// JSONType determines the JSON input serialization type.
type JSONType string

// JSON input serialization types.
//
// The type is spelled out on every line: in a const group, an entry that has
// its own value but no type is an untyped string constant, not a JSONType.
const (
	JSONDocumentType JSONType = "Document"
	JSONStreamType   JSONType = "Stream"
	JSONLinesType    JSONType = "Lines"
)
|
||||
|
||||
// ObjectSelectRequest - represents the input select body
|
||||
type ObjectSelectRequest struct {
|
||||
XMLName xml.Name `xml:"SelectObjectContentRequest" json:"-"`
|
||||
Expression string
|
||||
ExpressionType QueryExpressionType
|
||||
InputSerialization struct {
|
||||
CompressionType SelectCompressionType
|
||||
CSV *struct {
|
||||
FileHeaderInfo CSVFileHeaderInfo
|
||||
RecordDelimiter string
|
||||
FieldDelimiter string
|
||||
QuoteCharacter string
|
||||
QuoteEscapeCharacter string
|
||||
Comments string
|
||||
}
|
||||
JSON *struct {
|
||||
Type JSONType
|
||||
}
|
||||
}
|
||||
OutputSerialization struct {
|
||||
CSV *struct {
|
||||
QuoteFields CSVQuoteFields
|
||||
RecordDelimiter string
|
||||
FieldDelimiter string
|
||||
QuoteCharacter string
|
||||
QuoteEscapeCharacter string
|
||||
}
|
||||
JSON *struct {
|
||||
RecordDelimiter string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ObjectIdentifier carries key name for the object to delete.
|
||||
type ObjectIdentifier struct {
|
||||
ObjectName string `xml:"Key"`
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/minio/minio/pkg/dns"
|
||||
"github.com/minio/minio/pkg/event"
|
||||
"github.com/minio/minio/pkg/hash"
|
||||
"github.com/minio/minio/pkg/s3select"
|
||||
)
|
||||
|
||||
// APIError structure
|
||||
|
@ -190,6 +191,7 @@ const (
|
|||
ErrAdminCredentialsMismatch
|
||||
ErrInsecureClientRequest
|
||||
ErrObjectTampered
|
||||
|
||||
ErrHealNotImplemented
|
||||
ErrHealNoSuchProcess
|
||||
ErrHealInvalidClientToken
|
||||
|
@ -197,6 +199,93 @@ const (
|
|||
ErrHealAlreadyRunning
|
||||
ErrHealOverlappingPaths
|
||||
ErrIncorrectContinuationToken
|
||||
|
||||
//S3 Select Errors
|
||||
ErrEmptyRequestBody
|
||||
ErrUnsupportedFunction
|
||||
ErrInvalidExpressionType
|
||||
ErrBusy
|
||||
ErrUnauthorizedAccess
|
||||
ErrExpressionTooLong
|
||||
ErrIllegalSQLFunctionArgument
|
||||
ErrInvalidKeyPath
|
||||
ErrInvalidCompressionFormat
|
||||
ErrInvalidFileHeaderInfo
|
||||
ErrInvalidJSONType
|
||||
ErrInvalidQuoteFields
|
||||
ErrInvalidRequestParameter
|
||||
ErrInvalidDataType
|
||||
ErrInvalidTextEncoding
|
||||
ErrInvalidDataSource
|
||||
ErrInvalidTableAlias
|
||||
ErrMissingRequiredParameter
|
||||
ErrObjectSerializationConflict
|
||||
ErrUnsupportedSQLOperation
|
||||
ErrUnsupportedSQLStructure
|
||||
ErrUnsupportedSyntax
|
||||
ErrUnsupportedRangeHeader
|
||||
ErrLexerInvalidChar
|
||||
ErrLexerInvalidOperator
|
||||
ErrLexerInvalidLiteral
|
||||
ErrLexerInvalidIONLiteral
|
||||
ErrParseExpectedDatePart
|
||||
ErrParseExpectedKeyword
|
||||
ErrParseExpectedTokenType
|
||||
ErrParseExpected2TokenTypes
|
||||
ErrParseExpectedNumber
|
||||
ErrParseExpectedRightParenBuiltinFunctionCall
|
||||
ErrParseExpectedTypeName
|
||||
ErrParseExpectedWhenClause
|
||||
ErrParseUnsupportedToken
|
||||
ErrParseUnsupportedLiteralsGroupBy
|
||||
ErrParseExpectedMember
|
||||
ErrParseUnsupportedSelect
|
||||
ErrParseUnsupportedCase
|
||||
ErrParseUnsupportedCaseClause
|
||||
ErrParseUnsupportedAlias
|
||||
ErrParseUnsupportedSyntax
|
||||
ErrParseUnknownOperator
|
||||
ErrParseInvalidPathComponent
|
||||
ErrParseMissingIdentAfterAt
|
||||
ErrParseUnexpectedOperator
|
||||
ErrParseUnexpectedTerm
|
||||
ErrParseUnexpectedToken
|
||||
ErrParseUnexpectedKeyword
|
||||
ErrParseExpectedExpression
|
||||
ErrParseExpectedLeftParenAfterCast
|
||||
ErrParseExpectedLeftParenValueConstructor
|
||||
ErrParseExpectedLeftParenBuiltinFunctionCall
|
||||
ErrParseExpectedArgumentDelimiter
|
||||
ErrParseCastArity
|
||||
ErrParseInvalidTypeParam
|
||||
ErrParseEmptySelect
|
||||
ErrParseSelectMissingFrom
|
||||
ErrParseExpectedIdentForGroupName
|
||||
ErrParseExpectedIdentForAlias
|
||||
ErrParseUnsupportedCallWithStar
|
||||
ErrParseNonUnaryAgregateFunctionCall
|
||||
ErrParseMalformedJoin
|
||||
ErrParseExpectedIdentForAt
|
||||
ErrParseAsteriskIsNotAloneInSelectList
|
||||
ErrParseCannotMixSqbAndWildcardInSelectList
|
||||
ErrParseInvalidContextForWildcardInSelectList
|
||||
ErrIncorrectSQLFunctionArgumentType
|
||||
ErrValueParseFailure
|
||||
ErrEvaluatorInvalidArguments
|
||||
ErrIntegerOverflow
|
||||
ErrLikeInvalidInputs
|
||||
ErrCastFailed
|
||||
ErrInvalidCast
|
||||
ErrEvaluatorInvalidTimestampFormatPattern
|
||||
ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing
|
||||
ErrEvaluatorTimestampFormatPatternDuplicateFields
|
||||
ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch
|
||||
ErrEvaluatorUnterminatedTimestampFormatPatternToken
|
||||
ErrEvaluatorInvalidTimestampFormatPatternToken
|
||||
ErrEvaluatorInvalidTimestampFormatPatternSymbol
|
||||
ErrEvaluatorBindingDoesNotExist
|
||||
ErrInvalidColumnIndex
|
||||
ErrMissingHeaders
|
||||
)
|
||||
|
||||
// error code to APIError structure, these fields carry respective
|
||||
|
@ -803,6 +892,7 @@ var errorCodeResponse = map[APIErrorCode]APIError{
|
|||
Description: "X-Amz-Expires must be less than a week (in seconds); that is, the given X-Amz-Expires must be less than 604800 seconds",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
|
||||
// Generic Invalid-Request error. Should be used for response errors only for unlikely
|
||||
// corner case errors for which introducing new APIErrorCode is not worth it. LogIf()
|
||||
// should be used to log the error at the source of the error for debugging purposes.
|
||||
|
@ -851,6 +941,432 @@ var errorCodeResponse = map[APIErrorCode]APIError{
|
|||
Description: "The continuation token provided is incorrect",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
//S3 Select API Errors
|
||||
ErrEmptyRequestBody: {
|
||||
Code: "EmptyRequestBody",
|
||||
Description: "Request body cannot be empty.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrUnsupportedFunction: {
|
||||
Code: "UnsupportedFunction",
|
||||
Description: "Encountered an unsupported SQL function.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidDataSource: {
|
||||
Code: "InvalidDataSource",
|
||||
Description: "Invalid data source type. Only CSV and JSON are supported at this time.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidExpressionType: {
|
||||
Code: "InvalidExpressionType",
|
||||
Description: "The ExpressionType is invalid. Only SQL expressions are supported at this time.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrBusy: {
|
||||
Code: "Busy",
|
||||
Description: "The service is unavailable. Please retry.",
|
||||
HTTPStatusCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
ErrUnauthorizedAccess: {
|
||||
Code: "UnauthorizedAccess",
|
||||
Description: "You are not authorized to perform this operation",
|
||||
HTTPStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
ErrExpressionTooLong: {
|
||||
Code: "ExpressionTooLong",
|
||||
Description: "The SQL expression is too long: The maximum byte-length for the SQL expression is 256 KB.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrIllegalSQLFunctionArgument: {
|
||||
Code: "IllegalSqlFunctionArgument",
|
||||
Description: "Illegal argument was used in the SQL function.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidKeyPath: {
|
||||
Code: "InvalidKeyPath",
|
||||
Description: "Key path in the SQL expression is invalid.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidCompressionFormat: {
|
||||
Code: "InvalidCompressionFormat",
|
||||
Description: "The file is not in a supported compression format. Only GZIP is supported at this time.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidFileHeaderInfo: {
|
||||
Code: "InvalidFileHeaderInfo",
|
||||
Description: "The FileHeaderInfo is invalid. Only NONE, USE, and IGNORE are supported.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidJSONType: {
|
||||
Code: "InvalidJsonType",
|
||||
Description: "The JsonType is invalid. Only DOCUMENT and LINES are supported at this time.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidQuoteFields: {
|
||||
Code: "InvalidQuoteFields",
|
||||
Description: "The QuoteFields is invalid. Only ALWAYS and ASNEEDED are supported.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidRequestParameter: {
|
||||
Code: "InvalidRequestParameter",
|
||||
Description: "The value of a parameter in SelectRequest element is invalid. Check the service API documentation and try again.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidDataType: {
|
||||
Code: "InvalidDataType",
|
||||
Description: "The SQL expression contains an invalid data type.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidTextEncoding: {
|
||||
Code: "InvalidTextEncoding",
|
||||
Description: "Invalid encoding type. Only UTF-8 encoding is supported at this time.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidTableAlias: {
|
||||
Code: "InvalidTableAlias",
|
||||
Description: "The SQL expression contains an invalid table alias.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrMissingRequiredParameter: {
|
||||
Code: "MissingRequiredParameter",
|
||||
Description: "The SelectRequest entity is missing a required parameter. Check the service documentation and try again.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrObjectSerializationConflict: {
|
||||
Code: "ObjectSerializationConflict",
|
||||
Description: "The SelectRequest entity can only contain one of CSV or JSON. Check the service documentation and try again.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrUnsupportedSQLOperation: {
|
||||
Code: "UnsupportedSqlOperation",
|
||||
Description: "Encountered an unsupported SQL operation.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrUnsupportedSQLStructure: {
|
||||
Code: "UnsupportedSqlStructure",
|
||||
Description: "Encountered an unsupported SQL structure. Check the SQL Reference.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrUnsupportedSyntax: {
|
||||
Code: "UnsupportedSyntax",
|
||||
Description: "Encountered invalid syntax.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrUnsupportedRangeHeader: {
|
||||
Code: "UnsupportedRangeHeader",
|
||||
Description: "Range header is not supported for this operation.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrLexerInvalidChar: {
|
||||
Code: "LexerInvalidChar",
|
||||
Description: "The SQL expression contains an invalid character.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
// The description names an operator, in line with the error code; the
// previous text spoke of an invalid literal.
ErrLexerInvalidOperator: {
	Code:           "LexerInvalidOperator",
	Description:    "The SQL expression contains an invalid operator.",
	HTTPStatusCode: http.StatusBadRequest,
},
|
||||
// The description names a literal, in line with the error code; the previous
// text spoke of an invalid operator.
ErrLexerInvalidLiteral: {
	Code:           "LexerInvalidLiteral",
	Description:    "The SQL expression contains an invalid literal.",
	HTTPStatusCode: http.StatusBadRequest,
},
|
||||
ErrLexerInvalidIONLiteral: {
|
||||
Code: "LexerInvalidIONLiteral",
|
||||
Description: "The SQL expression contains an invalid operator.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedDatePart: {
|
||||
Code: "ParseExpectedDatePart",
|
||||
Description: "Did not find the expected date part in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedKeyword: {
|
||||
Code: "ParseExpectedKeyword",
|
||||
Description: "Did not find the expected keyword in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedTokenType: {
|
||||
Code: "ParseExpectedTokenType",
|
||||
Description: "Did not find the expected token in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpected2TokenTypes: {
|
||||
Code: "ParseExpected2TokenTypes",
|
||||
Description: "Did not find the expected token in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedNumber: {
|
||||
Code: "ParseExpectedNumber",
|
||||
Description: "Did not find the expected number in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedRightParenBuiltinFunctionCall: {
|
||||
Code: "ParseExpectedRightParenBuiltinFunctionCall",
|
||||
Description: "Did not find the expected right parenthesis character in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedTypeName: {
|
||||
Code: "ParseExpectedTypeName",
|
||||
Description: "Did not find the expected type name in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedWhenClause: {
|
||||
Code: "ParseExpectedWhenClause",
|
||||
Description: "Did not find the expected WHEN clause in the SQL expression. CASE is not supported.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnsupportedToken: {
|
||||
Code: "ParseUnsupportedToken",
|
||||
Description: "The SQL expression contains an unsupported token.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnsupportedLiteralsGroupBy: {
|
||||
Code: "ParseUnsupportedLiteralsGroupBy",
|
||||
Description: "The SQL expression contains an unsupported use of GROUP BY.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedMember: {
|
||||
Code: "ParseExpectedMember",
|
||||
Description: "The SQL expression contains an unsupported use of MEMBER.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnsupportedSelect: {
|
||||
Code: "ParseUnsupportedSelect",
|
||||
Description: "The SQL expression contains an unsupported use of SELECT.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnsupportedCase: {
|
||||
Code: "ParseUnsupportedCase",
|
||||
Description: "The SQL expression contains an unsupported use of CASE.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnsupportedCaseClause: {
|
||||
Code: "ParseUnsupportedCaseClause",
|
||||
Description: "The SQL expression contains an unsupported use of CASE.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnsupportedAlias: {
|
||||
Code: "ParseUnsupportedAlias",
|
||||
Description: "The SQL expression contains an unsupported use of ALIAS.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnsupportedSyntax: {
|
||||
Code: "ParseUnsupportedSyntax",
|
||||
Description: "The SQL expression contains unsupported syntax.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnknownOperator: {
|
||||
Code: "ParseUnknownOperator",
|
||||
Description: "The SQL expression contains an invalid operator.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseInvalidPathComponent: {
|
||||
Code: "ParseInvalidPathComponent",
|
||||
Description: "The SQL expression contains an invalid path component.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseMissingIdentAfterAt: {
|
||||
Code: "ParseMissingIdentAfterAt",
|
||||
Description: "Did not find the expected identifier after the @ symbol in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnexpectedOperator: {
|
||||
Code: "ParseUnexpectedOperator",
|
||||
Description: "The SQL expression contains an unexpected operator.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnexpectedTerm: {
|
||||
Code: "ParseUnexpectedTerm",
|
||||
Description: "The SQL expression contains an unexpected term.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnexpectedToken: {
|
||||
Code: "ParseUnexpectedToken",
|
||||
Description: "The SQL expression contains an unexpected token.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnexpectedKeyword: {
|
||||
Code: "ParseUnexpectedKeyword",
|
||||
Description: "The SQL expression contains an unexpected keyword.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedExpression: {
|
||||
Code: "ParseExpectedExpression",
|
||||
Description: "Did not find the expected SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedLeftParenAfterCast: {
|
||||
Code: "ParseExpectedLeftParenAfterCast",
|
||||
Description: "Did not find expected the left parenthesis in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedLeftParenValueConstructor: {
|
||||
Code: "ParseExpectedLeftParenValueConstructor",
|
||||
Description: "Did not find expected the left parenthesis in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedLeftParenBuiltinFunctionCall: {
|
||||
Code: "ParseExpectedLeftParenBuiltinFunctionCall",
|
||||
Description: "Did not find the expected left parenthesis in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedArgumentDelimiter: {
|
||||
Code: "ParseExpectedArgumentDelimiter",
|
||||
Description: "Did not find the expected argument delimiter in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseCastArity: {
|
||||
Code: "ParseCastArity",
|
||||
Description: "The SQL expression CAST has incorrect arity.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseInvalidTypeParam: {
|
||||
Code: "ParseInvalidTypeParam",
|
||||
Description: "The SQL expression contains an invalid parameter value.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseEmptySelect: {
|
||||
Code: "ParseEmptySelect",
|
||||
Description: "The SQL expression contains an empty SELECT.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
// The description reports the missing FROM clause; the previous text was the
// GROUP message that belongs to ErrParseExpectedIdentForGroupName.
ErrParseSelectMissingFrom: {
	Code:           "ParseSelectMissingFrom",
	Description:    "The SQL expression contains a missing FROM after SELECT list.",
	HTTPStatusCode: http.StatusBadRequest,
},
|
||||
ErrParseExpectedIdentForGroupName: {
|
||||
Code: "ParseExpectedIdentForGroupName",
|
||||
Description: "GROUP is not supported in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedIdentForAlias: {
|
||||
Code: "ParseExpectedIdentForAlias",
|
||||
Description: "Did not find the expected identifier for the alias in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseUnsupportedCallWithStar: {
|
||||
Code: "ParseUnsupportedCallWithStar",
|
||||
Description: "Only COUNT with (*) as a parameter is supported in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseNonUnaryAgregateFunctionCall: {
|
||||
Code: "ParseNonUnaryAgregateFunctionCall",
|
||||
Description: "Only one argument is supported for aggregate functions in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseMalformedJoin: {
|
||||
Code: "ParseMalformedJoin",
|
||||
Description: "JOIN is not supported in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseExpectedIdentForAt: {
|
||||
Code: "ParseExpectedIdentForAt",
|
||||
Description: "Did not find the expected identifier for AT name in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseAsteriskIsNotAloneInSelectList: {
|
||||
Code: "ParseAsteriskIsNotAloneInSelectList",
|
||||
Description: "Other expressions are not allowed in the SELECT list when '*' is used without dot notation in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseCannotMixSqbAndWildcardInSelectList: {
|
||||
Code: "ParseCannotMixSqbAndWildcardInSelectList",
|
||||
Description: "Cannot mix [] and * in the same expression in a SELECT list in SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrParseInvalidContextForWildcardInSelectList: {
|
||||
Code: "ParseInvalidContextForWildcardInSelectList",
|
||||
Description: "Invalid use of * in SELECT list in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrIncorrectSQLFunctionArgumentType: {
|
||||
Code: "IncorrectSqlFunctionArgumentType",
|
||||
Description: "Incorrect type of arguments in function call in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrValueParseFailure: {
|
||||
Code: "ValueParseFailure",
|
||||
Description: "Time stamp parse failure in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrEvaluatorInvalidArguments: {
|
||||
Code: "EvaluatorInvalidArguments",
|
||||
Description: "Incorrect number of arguments in the function call in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrIntegerOverflow: {
|
||||
Code: "IntegerOverflow",
|
||||
Description: "Int overflow or underflow in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrLikeInvalidInputs: {
|
||||
Code: "LikeInvalidInputs",
|
||||
Description: "Invalid argument given to the LIKE clause in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrCastFailed: {
|
||||
Code: "CastFailed",
|
||||
Description: "Attempt to convert from one data type to another using CAST failed in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidCast: {
|
||||
Code: "InvalidCast",
|
||||
Description: "Attempt to convert from one data type to another using CAST failed in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrEvaluatorInvalidTimestampFormatPattern: {
|
||||
Code: "EvaluatorInvalidTimestampFormatPattern",
|
||||
Description: "Time stamp format pattern requires additional fields in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing: {
|
||||
Code: "EvaluatorInvalidTimestampFormatPatternSymbolForParsing",
|
||||
Description: "Time stamp format pattern contains a valid format symbol that cannot be applied to time stamp parsing in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrEvaluatorTimestampFormatPatternDuplicateFields: {
|
||||
Code: "EvaluatorTimestampFormatPatternDuplicateFields",
|
||||
Description: "Time stamp format pattern contains multiple format specifiers representing the time stamp field in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
// Code and description now describe the hour-clock / AM-PM mismatch; the
// previous values were those of the unterminated-token error.
ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch: {
	Code:           "EvaluatorTimestampFormatPatternHourClockAmPmMismatch",
	Description:    "Time stamp format pattern contains a 12-hour hour of day format symbol but doesn't also contain an AM/PM field, or it contains a 24-hour hour of day format specifier and contains an AM/PM field in the SQL expression.",
	HTTPStatusCode: http.StatusBadRequest,
},
|
||||
// Code and description now describe an unterminated token; the previous
// values duplicated ErrEvaluatorInvalidTimestampFormatPatternToken.
ErrEvaluatorUnterminatedTimestampFormatPatternToken: {
	Code:           "EvaluatorUnterminatedTimestampFormatPatternToken",
	Description:    "Time stamp format pattern contains unterminated token in the SQL expression.",
	HTTPStatusCode: http.StatusBadRequest,
},
|
||||
ErrEvaluatorInvalidTimestampFormatPatternToken: {
|
||||
Code: "EvaluatorInvalidTimestampFormatPatternToken",
|
||||
Description: "Time stamp format pattern contains an invalid token in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrEvaluatorInvalidTimestampFormatPatternSymbol: {
|
||||
Code: "EvaluatorInvalidTimestampFormatPatternSymbol",
|
||||
Description: "Time stamp format pattern contains an invalid symbol in the SQL expression.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
ErrInvalidColumnIndex: {
|
||||
Code: "InvalidColumnIndex",
|
||||
Description: "Column index in the SQL expression is invalid.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
// The client-facing code carries no "Err" prefix, like every other entry in
// this table; the prefix belongs to the Go identifier only.
ErrEvaluatorBindingDoesNotExist: {
	Code:           "EvaluatorBindingDoesNotExist",
	Description:    "A column name or a path provided does not exist in the SQL expression",
	HTTPStatusCode: http.StatusBadRequest,
},
|
||||
ErrMissingHeaders: {
|
||||
Code: "MissingHeaders",
|
||||
Description: "Some headers in the query are missing from the file. Check the file and try again.",
|
||||
HTTPStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
// Add your error structure here.
|
||||
}
|
||||
|
||||
|
@ -900,6 +1416,169 @@ func toAPIErrorCode(err error) (apiErr APIErrorCode) {
|
|||
case context.Canceled, context.DeadlineExceeded:
|
||||
apiErr = ErrOperationTimedOut
|
||||
}
|
||||
switch err {
|
||||
case s3select.ErrBusy:
|
||||
apiErr = ErrBusy
|
||||
case s3select.ErrUnauthorizedAccess:
|
||||
apiErr = ErrUnauthorizedAccess
|
||||
case s3select.ErrExpressionTooLong:
|
||||
apiErr = ErrExpressionTooLong
|
||||
case s3select.ErrIllegalSQLFunctionArgument:
|
||||
apiErr = ErrIllegalSQLFunctionArgument
|
||||
case s3select.ErrInvalidKeyPath:
|
||||
apiErr = ErrInvalidKeyPath
|
||||
case s3select.ErrInvalidCompressionFormat:
|
||||
apiErr = ErrInvalidCompressionFormat
|
||||
case s3select.ErrInvalidFileHeaderInfo:
|
||||
apiErr = ErrInvalidFileHeaderInfo
|
||||
case s3select.ErrInvalidJSONType:
|
||||
apiErr = ErrInvalidJSONType
|
||||
case s3select.ErrInvalidQuoteFields:
|
||||
apiErr = ErrInvalidQuoteFields
|
||||
case s3select.ErrInvalidRequestParameter:
|
||||
apiErr = ErrInvalidRequestParameter
|
||||
case s3select.ErrInvalidDataType:
|
||||
apiErr = ErrInvalidDataType
|
||||
case s3select.ErrInvalidTextEncoding:
|
||||
apiErr = ErrInvalidTextEncoding
|
||||
case s3select.ErrInvalidTableAlias:
|
||||
apiErr = ErrInvalidTableAlias
|
||||
case s3select.ErrMissingRequiredParameter:
|
||||
apiErr = ErrMissingRequiredParameter
|
||||
case s3select.ErrObjectSerializationConflict:
|
||||
apiErr = ErrObjectSerializationConflict
|
||||
case s3select.ErrUnsupportedSQLOperation:
|
||||
apiErr = ErrUnsupportedSQLOperation
|
||||
case s3select.ErrUnsupportedSQLStructure:
|
||||
apiErr = ErrUnsupportedSQLStructure
|
||||
case s3select.ErrUnsupportedSyntax:
|
||||
apiErr = ErrUnsupportedSyntax
|
||||
case s3select.ErrUnsupportedRangeHeader:
|
||||
apiErr = ErrUnsupportedRangeHeader
|
||||
case s3select.ErrLexerInvalidChar:
|
||||
apiErr = ErrLexerInvalidChar
|
||||
case s3select.ErrLexerInvalidOperator:
|
||||
apiErr = ErrLexerInvalidOperator
|
||||
case s3select.ErrLexerInvalidLiteral:
|
||||
apiErr = ErrLexerInvalidLiteral
|
||||
case s3select.ErrLexerInvalidIONLiteral:
|
||||
apiErr = ErrLexerInvalidIONLiteral
|
||||
case s3select.ErrParseExpectedDatePart:
|
||||
apiErr = ErrParseExpectedDatePart
|
||||
case s3select.ErrParseExpectedKeyword:
|
||||
apiErr = ErrParseExpectedKeyword
|
||||
case s3select.ErrParseExpectedTokenType:
|
||||
apiErr = ErrParseExpectedTokenType
|
||||
case s3select.ErrParseExpected2TokenTypes:
|
||||
apiErr = ErrParseExpected2TokenTypes
|
||||
case s3select.ErrParseExpectedNumber:
|
||||
apiErr = ErrParseExpectedNumber
|
||||
case s3select.ErrParseExpectedRightParenBuiltinFunctionCall:
|
||||
apiErr = ErrParseExpectedRightParenBuiltinFunctionCall
|
||||
case s3select.ErrParseExpectedTypeName:
|
||||
apiErr = ErrParseExpectedTypeName
|
||||
case s3select.ErrParseExpectedWhenClause:
|
||||
apiErr = ErrParseExpectedWhenClause
|
||||
case s3select.ErrParseUnsupportedToken:
|
||||
apiErr = ErrParseUnsupportedToken
|
||||
case s3select.ErrParseUnsupportedLiteralsGroupBy:
|
||||
apiErr = ErrParseUnsupportedLiteralsGroupBy
|
||||
case s3select.ErrParseExpectedMember:
|
||||
apiErr = ErrParseExpectedMember
|
||||
case s3select.ErrParseUnsupportedSelect:
|
||||
apiErr = ErrParseUnsupportedSelect
|
||||
case s3select.ErrParseUnsupportedCase:
|
||||
apiErr = ErrParseUnsupportedCase
|
||||
case s3select.ErrParseUnsupportedCaseClause:
|
||||
apiErr = ErrParseUnsupportedCaseClause
|
||||
case s3select.ErrParseUnsupportedAlias:
|
||||
apiErr = ErrParseUnsupportedAlias
|
||||
case s3select.ErrParseUnsupportedSyntax:
|
||||
apiErr = ErrParseUnsupportedSyntax
|
||||
case s3select.ErrParseUnknownOperator:
|
||||
apiErr = ErrParseUnknownOperator
|
||||
case s3select.ErrParseInvalidPathComponent:
|
||||
apiErr = ErrParseInvalidPathComponent
|
||||
case s3select.ErrParseMissingIdentAfterAt:
|
||||
apiErr = ErrParseMissingIdentAfterAt
|
||||
case s3select.ErrParseUnexpectedOperator:
|
||||
apiErr = ErrParseUnexpectedOperator
|
||||
case s3select.ErrParseUnexpectedTerm:
|
||||
apiErr = ErrParseUnexpectedTerm
|
||||
case s3select.ErrParseUnexpectedToken:
|
||||
apiErr = ErrParseUnexpectedToken
|
||||
case s3select.ErrParseUnexpectedKeyword:
|
||||
apiErr = ErrParseUnexpectedKeyword
|
||||
case s3select.ErrParseExpectedExpression:
|
||||
apiErr = ErrParseExpectedExpression
|
||||
case s3select.ErrParseExpectedLeftParenAfterCast:
|
||||
apiErr = ErrParseExpectedLeftParenAfterCast
|
||||
case s3select.ErrParseExpectedLeftParenValueConstructor:
|
||||
apiErr = ErrParseExpectedLeftParenValueConstructor
|
||||
case s3select.ErrParseExpectedLeftParenBuiltinFunctionCall:
|
||||
apiErr = ErrParseExpectedLeftParenBuiltinFunctionCall
|
||||
case s3select.ErrParseExpectedArgumentDelimiter:
|
||||
apiErr = ErrParseExpectedArgumentDelimiter
|
||||
case s3select.ErrParseCastArity:
|
||||
apiErr = ErrParseCastArity
|
||||
case s3select.ErrParseInvalidTypeParam:
|
||||
apiErr = ErrParseInvalidTypeParam
|
||||
case s3select.ErrParseEmptySelect:
|
||||
apiErr = ErrParseEmptySelect
|
||||
case s3select.ErrParseSelectMissingFrom:
|
||||
apiErr = ErrParseSelectMissingFrom
|
||||
case s3select.ErrParseExpectedIdentForGroupName:
|
||||
apiErr = ErrParseExpectedIdentForGroupName
|
||||
case s3select.ErrParseExpectedIdentForAlias:
|
||||
apiErr = ErrParseExpectedIdentForAlias
|
||||
case s3select.ErrParseUnsupportedCallWithStar:
|
||||
apiErr = ErrParseUnsupportedCallWithStar
|
||||
case s3select.ErrParseNonUnaryAgregateFunctionCall:
|
||||
apiErr = ErrParseNonUnaryAgregateFunctionCall
|
||||
case s3select.ErrParseMalformedJoin:
|
||||
apiErr = ErrParseMalformedJoin
|
||||
case s3select.ErrParseExpectedIdentForAt:
|
||||
apiErr = ErrParseExpectedIdentForAt
|
||||
case s3select.ErrParseAsteriskIsNotAloneInSelectList:
|
||||
apiErr = ErrParseAsteriskIsNotAloneInSelectList
|
||||
case s3select.ErrParseCannotMixSqbAndWildcardInSelectList:
|
||||
apiErr = ErrParseCannotMixSqbAndWildcardInSelectList
|
||||
case s3select.ErrParseInvalidContextForWildcardInSelectList:
|
||||
apiErr = ErrParseInvalidContextForWildcardInSelectList
|
||||
case s3select.ErrIncorrectSQLFunctionArgumentType:
|
||||
apiErr = ErrIncorrectSQLFunctionArgumentType
|
||||
case s3select.ErrValueParseFailure:
|
||||
apiErr = ErrValueParseFailure
|
||||
case s3select.ErrIntegerOverflow:
|
||||
apiErr = ErrIntegerOverflow
|
||||
case s3select.ErrLikeInvalidInputs:
|
||||
apiErr = ErrLikeInvalidInputs
|
||||
case s3select.ErrCastFailed:
|
||||
apiErr = ErrCastFailed
|
||||
case s3select.ErrInvalidCast:
|
||||
apiErr = ErrInvalidCast
|
||||
case s3select.ErrEvaluatorInvalidTimestampFormatPattern:
|
||||
apiErr = ErrEvaluatorInvalidTimestampFormatPattern
|
||||
case s3select.ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing:
|
||||
apiErr = ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing
|
||||
case s3select.ErrEvaluatorTimestampFormatPatternDuplicateFields:
|
||||
apiErr = ErrEvaluatorTimestampFormatPatternDuplicateFields
|
||||
case s3select.ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch:
|
||||
apiErr = ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch
|
||||
case s3select.ErrEvaluatorUnterminatedTimestampFormatPatternToken:
|
||||
apiErr = ErrEvaluatorUnterminatedTimestampFormatPatternToken
|
||||
case s3select.ErrEvaluatorInvalidTimestampFormatPatternToken:
|
||||
apiErr = ErrEvaluatorInvalidTimestampFormatPatternToken
|
||||
case s3select.ErrEvaluatorInvalidTimestampFormatPatternSymbol:
|
||||
apiErr = ErrEvaluatorInvalidTimestampFormatPatternSymbol
|
||||
case s3select.ErrInvalidColumnIndex:
|
||||
apiErr = ErrInvalidColumnIndex
|
||||
case s3select.ErrEvaluatorBindingDoesNotExist:
|
||||
apiErr = ErrEvaluatorBindingDoesNotExist
|
||||
case s3select.ErrMissingHeaders:
|
||||
apiErr = ErrMissingHeaders
|
||||
|
||||
}
|
||||
|
||||
if apiErr != ErrNone {
|
||||
// If there was a match in the above switch case.
|
||||
|
|
|
@ -62,6 +62,8 @@ func registerAPIRouter(router *mux.Router) {
|
|||
bucket.Methods("DELETE").Path("/{object:.+}").HandlerFunc(httpTraceAll(api.AbortMultipartUploadHandler)).Queries("uploadId", "{uploadId:.*}")
|
||||
// GetObjectACL - this is a dummy call.
|
||||
bucket.Methods("GET").Path("/{object:.+}").HandlerFunc(httpTraceHdrs(api.GetObjectACLHandler)).Queries("acl", "")
|
||||
// SelectObjectContent
|
||||
bucket.Methods("POST").Path("/{object:.+}").HandlerFunc(httpTraceHdrs(api.SelectObjectContentHandler)).Queries("select", "").Queries("select-type", "2")
|
||||
// GetObject
|
||||
bucket.Methods("GET").Path("/{object:.+}").HandlerFunc(httpTraceHdrs(api.GetObjectHandler))
|
||||
// CopyObject
|
||||
|
|
|
@ -475,7 +475,6 @@ var notimplementedObjectResourceNames = map[string]bool{
|
|||
"policy": true,
|
||||
"tagging": true,
|
||||
"restore": true,
|
||||
"select": true,
|
||||
}
|
||||
|
||||
// Resource handler ServeHTTP() wrapper
|
||||
|
|
|
@ -30,6 +30,7 @@ import (
|
|||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
miniogo "github.com/minio/minio-go"
|
||||
|
@ -40,6 +41,7 @@ import (
|
|||
"github.com/minio/minio/pkg/hash"
|
||||
"github.com/minio/minio/pkg/ioutil"
|
||||
"github.com/minio/minio/pkg/policy"
|
||||
"github.com/minio/minio/pkg/s3select"
|
||||
sha256 "github.com/minio/sha256-simd"
|
||||
"github.com/minio/sio"
|
||||
)
|
||||
|
@ -63,6 +65,191 @@ func setHeadGetRespHeaders(w http.ResponseWriter, reqParams url.Values) {
|
|||
}
|
||||
}
|
||||
|
||||
// SelectObjectContentHandler - GET Object?select
|
||||
// ----------
|
||||
// This implementation of the GET operation retrieves object content based
|
||||
// on an SQL expression. In the request, along with the sql expression, you must
|
||||
// also specify a data serialization format (JSON, CSV) of the object.
|
||||
func (api objectAPIHandlers) SelectObjectContentHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := newContext(r, w, "SelectObject")
|
||||
var object, bucket string
|
||||
vars := mux.Vars(r)
|
||||
bucket = vars["bucket"]
|
||||
object = vars["object"]
|
||||
|
||||
// Fetch object stat info.
|
||||
objectAPI := api.ObjectAPI()
|
||||
if objectAPI == nil {
|
||||
writeErrorResponse(w, ErrServerNotInitialized, r.URL)
|
||||
return
|
||||
}
|
||||
|
||||
getObjectInfo := objectAPI.GetObjectInfo
|
||||
if api.CacheAPI() != nil {
|
||||
getObjectInfo = api.CacheAPI().GetObjectInfo
|
||||
}
|
||||
|
||||
if s3Error := checkRequestAuthType(ctx, r, policy.GetObjectAction, bucket, object); s3Error != ErrNone {
|
||||
if getRequestAuthType(r) == authTypeAnonymous {
|
||||
// As per "Permission" section in
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectGET.html If
|
||||
// the object you request does not exist, the error Amazon S3 returns
|
||||
// depends on whether you also have the s3:ListBucket permission. * If you
|
||||
// have the s3:ListBucket permission on the bucket, Amazon S3 will return
|
||||
// an HTTP status code 404 ("no such key") error. * if you don’t have the
|
||||
// s3:ListBucket permission, Amazon S3 will return an HTTP status code 403
|
||||
// ("access denied") error.`
|
||||
if globalPolicySys.IsAllowed(policy.Args{
|
||||
Action: policy.ListBucketAction,
|
||||
BucketName: bucket,
|
||||
ConditionValues: getConditionValues(r, ""),
|
||||
IsOwner: false,
|
||||
}) {
|
||||
_, err := getObjectInfo(ctx, bucket, object)
|
||||
if toAPIErrorCode(err) == ErrNoSuchKey {
|
||||
s3Error = ErrNoSuchKey
|
||||
}
|
||||
}
|
||||
}
|
||||
writeErrorResponse(w, s3Error, r.URL)
|
||||
return
|
||||
}
|
||||
if r.ContentLength <= 0 {
|
||||
writeErrorResponse(w, ErrEmptyRequestBody, r.URL)
|
||||
return
|
||||
}
|
||||
var selectReq ObjectSelectRequest
|
||||
if err := xmlDecoder(r.Body, &selectReq, r.ContentLength); err != nil {
|
||||
fmt.Println(err)
|
||||
writeErrorResponse(w, ErrMalformedXML, r.URL)
|
||||
return
|
||||
}
|
||||
|
||||
objInfo, err := getObjectInfo(ctx, bucket, object)
|
||||
if err != nil {
|
||||
writeErrorResponse(w, toAPIErrorCode(err), r.URL)
|
||||
return
|
||||
}
|
||||
// Get request range.
|
||||
rangeHeader := r.Header.Get("Range")
|
||||
if rangeHeader != "" {
|
||||
writeErrorResponse(w, ErrUnsupportedRangeHeader, r.URL)
|
||||
return
|
||||
}
|
||||
|
||||
if selectReq.InputSerialization.CompressionType == SelectCompressionGZIP {
|
||||
if !strings.Contains(objInfo.ContentType, "gzip") {
|
||||
writeErrorResponse(w, ErrInvalidDataSource, r.URL)
|
||||
return
|
||||
}
|
||||
}
|
||||
if selectReq.InputSerialization.CompressionType == SelectCompressionBZIP {
|
||||
if !strings.Contains(objInfo.ContentType, "bzip") {
|
||||
writeErrorResponse(w, ErrInvalidDataSource, r.URL)
|
||||
return
|
||||
}
|
||||
}
|
||||
if selectReq.InputSerialization.CompressionType == SelectCompressionNONE ||
|
||||
selectReq.InputSerialization.CompressionType == "" {
|
||||
selectReq.InputSerialization.CompressionType = SelectCompressionNONE
|
||||
if !strings.Contains(objInfo.ContentType, "text/csv") {
|
||||
writeErrorResponse(w, ErrInvalidDataSource, r.URL)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !strings.EqualFold(string(selectReq.ExpressionType), "SQL") {
|
||||
writeErrorResponse(w, ErrInvalidExpressionType, r.URL)
|
||||
return
|
||||
}
|
||||
if len(selectReq.Expression) >= (256 * 1000) {
|
||||
writeErrorResponse(w, ErrExpressionTooLong, r.URL)
|
||||
}
|
||||
if selectReq.InputSerialization.CSV.FileHeaderInfo != CSVFileHeaderInfoUse &&
|
||||
selectReq.InputSerialization.CSV.FileHeaderInfo != CSVFileHeaderInfoNone &&
|
||||
selectReq.InputSerialization.CSV.FileHeaderInfo != CSVFileHeaderInfoIgnore &&
|
||||
selectReq.InputSerialization.CSV.FileHeaderInfo != "" {
|
||||
writeErrorResponse(w, ErrInvalidFileHeaderInfo, r.URL)
|
||||
}
|
||||
if selectReq.OutputSerialization.CSV.QuoteFields != CSVQuoteFieldsAlways &&
|
||||
selectReq.OutputSerialization.CSV.QuoteFields != CSVQuoteFieldsAsNeeded &&
|
||||
selectReq.OutputSerialization.CSV.QuoteFields != "" {
|
||||
writeErrorResponse(w, ErrInvalidQuoteFields, r.URL)
|
||||
}
|
||||
|
||||
getObject := objectAPI.GetObject
|
||||
if api.CacheAPI() != nil && !hasSSECustomerHeader(r.Header) {
|
||||
getObject = api.CacheAPI().GetObject
|
||||
}
|
||||
|
||||
reader, pipewriter := io.Pipe()
|
||||
|
||||
// Get the object.
|
||||
var startOffset int64
|
||||
length := objInfo.Size
|
||||
|
||||
var writer io.Writer
|
||||
writer = pipewriter
|
||||
if objectAPI.IsEncryptionSupported() {
|
||||
if hasSSECustomerHeader(r.Header) {
|
||||
// Response writer should be limited early on for decryption upto required length,
|
||||
// additionally also skipping mod(offset)64KiB boundaries.
|
||||
writer = ioutil.LimitedWriter(writer, startOffset%(64*1024), length)
|
||||
|
||||
writer, startOffset, length, err = DecryptBlocksRequest(writer, r, bucket,
|
||||
object, startOffset, length, objInfo, false)
|
||||
if err != nil {
|
||||
writeErrorResponse(w, toAPIErrorCode(err), r.URL)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
defer reader.Close()
|
||||
if gerr := getObject(ctx, bucket, object, 0, objInfo.Size, writer,
|
||||
objInfo.ETag); gerr != nil {
|
||||
pipewriter.CloseWithError(gerr)
|
||||
return
|
||||
}
|
||||
pipewriter.Close() // Close writer explicitly signaling we wrote all data.
|
||||
}()
|
||||
|
||||
//s3select //Options
|
||||
if selectReq.OutputSerialization.CSV.FieldDelimiter == "" {
|
||||
selectReq.OutputSerialization.CSV.FieldDelimiter = ","
|
||||
}
|
||||
if selectReq.InputSerialization.CSV.FileHeaderInfo == "" {
|
||||
selectReq.InputSerialization.CSV.FileHeaderInfo = CSVFileHeaderInfoNone
|
||||
}
|
||||
if selectReq.InputSerialization.CSV != nil {
|
||||
options := &s3select.Options{
|
||||
HasHeader: selectReq.InputSerialization.CSV.FileHeaderInfo != CSVFileHeaderInfoNone,
|
||||
FieldDelimiter: selectReq.InputSerialization.CSV.FieldDelimiter,
|
||||
Comments: selectReq.InputSerialization.CSV.Comments,
|
||||
Name: "S3Object", // Default table name for all objects
|
||||
ReadFrom: reader,
|
||||
Compressed: string(selectReq.InputSerialization.CompressionType),
|
||||
Expression: selectReq.Expression,
|
||||
OutputFieldDelimiter: selectReq.OutputSerialization.CSV.FieldDelimiter,
|
||||
StreamSize: objInfo.Size,
|
||||
HeaderOpt: selectReq.InputSerialization.CSV.FileHeaderInfo == CSVFileHeaderInfoUse,
|
||||
}
|
||||
s3s, err := s3select.NewInput(options)
|
||||
if err != nil {
|
||||
writeErrorResponse(w, toAPIErrorCode(err), r.URL)
|
||||
return
|
||||
}
|
||||
_, _, _, _, _, _, err = s3s.ParseSelect(selectReq.Expression)
|
||||
if err != nil {
|
||||
writeErrorResponse(w, toAPIErrorCode(err), r.URL)
|
||||
return
|
||||
}
|
||||
if err := s3s.Execute(w); err != nil {
|
||||
logger.LogIf(ctx, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// GetObjectHandler - GET Object
|
||||
// ----------
|
||||
// This implementation of the GET operation retrieves object. To use GET,
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
# How to use Minio S3 Select [![Slack](https://slack.minio.io/slack?type=svg)](https://slack.minio.io)
|
||||
|
||||
This document explains the current limitations of Minio S3 Select support and shows sample usage.
|
||||
|
||||
## 1. Features to be implemented
|
||||
1). JSON documents as supported Objects
|
||||
|
||||
2). CAST expression
|
||||
|
||||
3). Date Functions
|
||||
|
||||
4). Returning types other than float from aggregation queries.
|
||||
|
||||
5). Bracket and Reversal Notation with SQL Like operator.
|
||||
|
||||
6). SUBSTRING is currently not supported, and TRIM only works with its default behavior of trimming leading and trailing spaces.
|
||||
|
||||
## 2. Sample Usage with AWS Boto Client
|
||||
```python
|
||||
import boto3
|
||||
from botocore.client import Config
|
||||
import os
|
||||
s3 = boto3.resource('s3',
|
||||
endpoint_url='ENDPOINT',
|
||||
aws_access_key_id='ACCESSKEY',
|
||||
aws_secret_access_key='SECRETKEY',
|
||||
config=Config(signature_version='s3v4'),
|
||||
region_name='us-east-1')
|
||||
s3_client = s3.meta.client
|
||||
|
||||
r = s3_client.select_object_content(
|
||||
Bucket='myBucket',
|
||||
Key='myKey',
|
||||
ExpressionType='SQL',
|
||||
Expression = "SELECT * FROM S3OBJECT AS A",
|
||||
InputSerialization = {'CSV': {"FieldDelimiter": ",","FileHeaderInfo":"USE"}},
|
||||
OutputSerialization = {'CSV': {}},
|
||||
)
|
||||
```
|
||||
## 3. Sample Usage with Minio-Go Client
|
||||
|
||||
```go
|
||||
// Initialize minio client object.
|
||||
minioClient, err := minio.New(endpoint, accessKeyID, secretAccessKey, useSSL)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
input := minio.SelectObjectInput{
|
||||
RecordDelimiter: "\n",
|
||||
FieldDelimiter: ",",
|
||||
FileHeaderInfo: minio.CSVFileHeaderInfoUse,
|
||||
}
|
||||
output := minio.SelectObjectOutput{
|
||||
RecordDelimiter: "\n",
|
||||
FieldDelimiter: ",",
|
||||
}
|
||||
opts := minio.SelectObjectOptions{
|
||||
Type: minio.SelectObjectTypeCSV,
|
||||
Input: input,
|
||||
Output: output,
|
||||
}
|
||||
myReader, err := minioClient.SelectObjectContent(ctx, "sqlselectapi", "player.csv", "Select * from S3OBJECT WHERE last_name = 'James'", opts)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer myReader.Close()
|
||||
|
||||
results, resultWriter := io.Pipe()
|
||||
go func() {
|
||||
defer resultWriter.Close()
|
||||
for event := range myReader.Events() {
|
||||
switch e := event.(type) {
|
||||
case *minio.RecordEvent:
|
||||
resultWriter.Write(e.Payload)
|
||||
case *minio.ProgressEvent:
|
||||
fmt.Println("Progress")
|
||||
case *minio.StatEvent:
|
||||
fmt.Println(string(e.Payload))
|
||||
case *minio.EndEvent:
|
||||
fmt.Println("Ended")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
resReader := csv.NewReader(results)
|
||||
for {
|
||||
record, err := resReader.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
// Print out the records
|
||||
fmt.Println(record)
|
||||
}
|
||||
if err := myReader.Err(); err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
```
|
|
@ -0,0 +1,486 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2018 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package s3select
|
||||
|
||||
import "errors"
|
||||
|
||||
//S3 errors below
|
||||
|
||||
// ErrBusy is an error if the service is too busy.
|
||||
var ErrBusy = errors.New("The service is unavailable. Please retry")
|
||||
|
||||
// ErrUnauthorizedAccess is an error if you lack the appropriate credentials to
|
||||
// access the object.
|
||||
var ErrUnauthorizedAccess = errors.New("You are not authorized to perform this operation")
|
||||
|
||||
// ErrExpressionTooLong is an error if your SQL expression is too long for
|
||||
// processing.
|
||||
var ErrExpressionTooLong = errors.New("The SQL expression is too long: The maximum byte-length for the SQL expression is 256 KB")
|
||||
|
||||
// ErrIllegalSQLFunctionArgument is an error if you provide an illegal argument
|
||||
// in the SQL function.
|
||||
var ErrIllegalSQLFunctionArgument = errors.New("Illegal argument was used in the SQL function")
|
||||
|
||||
// ErrInvalidColumnIndex is an error if you provide a column index which is not
|
||||
// valid.
|
||||
var ErrInvalidColumnIndex = errors.New("Column index in the SQL expression is invalid")
|
||||
|
||||
// ErrInvalidKeyPath is an error if you provide a key in the SQL expression that
|
||||
// is invalid.
|
||||
var ErrInvalidKeyPath = errors.New("Key path in the SQL expression is invalid")
|
||||
|
||||
// ErrColumnTooLong is an error if your query results in a column that is
|
||||
// greater than the max amount of characters per column of 1mb
|
||||
var ErrColumnTooLong = errors.New("The length of a column in the result is greater than maxCharsPerColumn of 1 MB")
|
||||
|
||||
// ErrOverMaxColumn is an error if the number of columns from the resulting
|
||||
// query is greater than 1Mb.
|
||||
var ErrOverMaxColumn = errors.New("The number of columns in the result is greater than maxColumnNumber of 1 MB")
|
||||
|
||||
// ErrOverMaxRecordSize is an error if the length of a record in the result is
|
||||
// greater than 1 Mb.
|
||||
var ErrOverMaxRecordSize = errors.New("The length of a record in the result is greater than maxCharsPerRecord of 1 MB")
|
||||
|
||||
// ErrMissingHeaders is an error if some of the headers that are requested in
|
||||
// the Select Query are not present in the file.
|
||||
var ErrMissingHeaders = errors.New("Some headers in the query are missing from the file. Check the file and try again")
|
||||
|
||||
// ErrInvalidCompressionFormat is an error if an unsupported compression type is
|
||||
// utilized with the select object query.
|
||||
var ErrInvalidCompressionFormat = errors.New("The file is not in a supported compression format. Only GZIP is supported at this time")
|
||||
|
||||
// ErrTruncatedInput is an error if the object is not compressed properly and an
|
||||
// error occurs during decompression.
|
||||
var ErrTruncatedInput = errors.New("Object decompression failed. Check that the object is properly compressed using the format specified in the request")
|
||||
|
||||
// ErrInvalidFileHeaderInfo is an error if the argument provided to the
|
||||
// FileHeader Argument is incorrect.
|
||||
var ErrInvalidFileHeaderInfo = errors.New("The FileHeaderInfo is invalid. Only NONE, USE, and IGNORE are supported")
|
||||
|
||||
// ErrInvalidJSONType is an error if the json format provided as an argument is
|
||||
// invalid.
|
||||
var ErrInvalidJSONType = errors.New("The JsonType is invalid. Only DOCUMENT and LINES are supported at this time")
|
||||
|
||||
// ErrInvalidQuoteFields is an error if the arguments provided to the
|
||||
// QuoteFields options are not valid.
|
||||
var ErrInvalidQuoteFields = errors.New("The QuoteFields is invalid. Only ALWAYS and ASNEEDED are supported")
|
||||
|
||||
// ErrInvalidRequestParameter is an error if the value of a parameter in the
|
||||
// request element is not valid.
|
||||
var ErrInvalidRequestParameter = errors.New("The value of a parameter in Request element is invalid. Check the service API documentation and try again")
|
||||
|
||||
// ErrCSVParsingError is an error if the CSV presents an error while being
|
||||
// parsed.
|
||||
var ErrCSVParsingError = errors.New("Encountered an Error parsing the CSV file. Check the file and try again")
|
||||
|
||||
// ErrJSONParsingError is an error if while parsing the JSON an error arises.
|
||||
var ErrJSONParsingError = errors.New("Encountered an error parsing the JSON file. Check the file and try again")
|
||||
|
||||
// ErrExternalEvalException is an error that arises if the query can not be
|
||||
// evaluated.
|
||||
var ErrExternalEvalException = errors.New("The query cannot be evaluated. Check the file and try again")
|
||||
|
||||
// ErrInvalidDataType is an error that occurs if the SQL expression contains an
|
||||
// invalid data type.
|
||||
var ErrInvalidDataType = errors.New("The SQL expression contains an invalid data type")
|
||||
|
||||
// ErrUnrecognizedFormatException is an error that arises if there is an invalid
|
||||
// record type.
|
||||
var ErrUnrecognizedFormatException = errors.New("Encountered an invalid record type")
|
||||
|
||||
// ErrInvalidTextEncoding is an error if the text encoding is not valid.
|
||||
var ErrInvalidTextEncoding = errors.New("Invalid encoding type. Only UTF-8 encoding is supported at this time")
|
||||
|
||||
// ErrInvalidTableAlias is an error that arises if the table alias provided in
|
||||
// the SQL expression is invalid.
|
||||
var ErrInvalidTableAlias = errors.New("The SQL expression contains an invalid table alias")
|
||||
|
||||
// ErrMultipleDataSourcesUnsupported is an error that arises if multiple data
|
||||
// sources are provided.
|
||||
var ErrMultipleDataSourcesUnsupported = errors.New("Multiple data sources are not supported")
|
||||
|
||||
// ErrMissingRequiredParameter is an error that arises if a required argument
|
||||
// is omitted from the Request.
|
||||
var ErrMissingRequiredParameter = errors.New("The Request entity is missing a required parameter. Check the service documentation and try again")
|
||||
|
||||
// ErrObjectSerializationConflict is an error that arises if an unsupported
|
||||
// output serialization is provided.
|
||||
var ErrObjectSerializationConflict = errors.New("The Request entity can only contain one of CSV or JSON. Check the service documentation and try again")
|
||||
|
||||
// ErrUnsupportedSQLOperation is an error that arises if an unsupported SQL
|
||||
// operation is used.
|
||||
var ErrUnsupportedSQLOperation = errors.New("Encountered an unsupported SQL operation")
|
||||
|
||||
// ErrUnsupportedSQLStructure is an error that occurs if an unsupported SQL
|
||||
// structure is used.
|
||||
var ErrUnsupportedSQLStructure = errors.New("Encountered an unsupported SQL structure. Check the SQL Reference")
|
||||
|
||||
// ErrUnsupportedStorageClass is an error that occurs if an invalid storage
|
||||
// class is present.
|
||||
var ErrUnsupportedStorageClass = errors.New("Encountered an invalid storage class. Only STANDARD, STANDARD_IA, and ONEZONE_IA storage classes are supported at this time")
|
||||
|
||||
// ErrUnsupportedSyntax is an error that occurs if invalid syntax is present in
|
||||
// the query.
|
||||
var ErrUnsupportedSyntax = errors.New("Encountered invalid syntax")
|
||||
|
||||
// ErrUnsupportedRangeHeader is an error that occurs if a range header is
|
||||
// provided.
|
||||
var ErrUnsupportedRangeHeader = errors.New("Range header is not supported for this operation")
|
||||
|
||||
// ErrLexerInvalidChar is an error that occurs if the SQL expression contains an
|
||||
// invalid character.
|
||||
var ErrLexerInvalidChar = errors.New("The SQL expression contains an invalid character")
|
||||
|
||||
// ErrLexerInvalidOperator is an error that occurs if an invalid operator is
|
||||
// used.
|
||||
var ErrLexerInvalidOperator = errors.New("The SQL expression contains an invalid operator")
|
||||
|
||||
// ErrLexerInvalidLiteral is an error that occurs if an invalid literal is used.
|
||||
var ErrLexerInvalidLiteral = errors.New("The SQL expression contains an invalid literal")
|
||||
|
||||
// ErrLexerInvalidIONLiteral is an error that occurs if an invalid operator is
|
||||
// used
|
||||
var ErrLexerInvalidIONLiteral = errors.New("The SQL expression contains an invalid operator")
|
||||
|
||||
// ErrParseExpectedDatePart is an error that occurs if the date part is not
|
||||
// found in the SQL expression.
|
||||
var ErrParseExpectedDatePart = errors.New("Did not find the expected date part in the SQL expression")
|
||||
|
||||
// ErrParseExpectedKeyword is an error that occurs if the expected keyword was
|
||||
// not found in the expression.
|
||||
var ErrParseExpectedKeyword = errors.New("Did not find the expected keyword in the SQL expression")
|
||||
|
||||
// ErrParseExpectedTokenType is an error that occurs if the expected token is
|
||||
// not found in the SQL expression.
|
||||
var ErrParseExpectedTokenType = errors.New("Did not find the expected token in the SQL expression")
|
||||
|
||||
// ErrParseExpected2TokenTypes is an error that occurs if 2 token types are not
|
||||
// found.
|
||||
var ErrParseExpected2TokenTypes = errors.New("Did not find the expected token in the SQL expression")
|
||||
|
||||
// ErrParseExpectedNumber is an error that occurs if a number is expected but
|
||||
// not found in the expression.
|
||||
var ErrParseExpectedNumber = errors.New("Did not find the expected number in the SQL expression")
|
||||
|
||||
// ErrParseExpectedRightParenBuiltinFunctionCall is an error that occurs if a
|
||||
// right parenthesis is missing.
|
||||
var ErrParseExpectedRightParenBuiltinFunctionCall = errors.New("Did not find the expected right parenthesis character in the SQL expression")
|
||||
|
||||
// ErrParseExpectedTypeName is an error that occurs if a type name is expected
|
||||
// but not found.
|
||||
var ErrParseExpectedTypeName = errors.New("Did not find the expected type name in the SQL expression")
|
||||
|
||||
// ErrParseExpectedWhenClause is an error that occurs if a When clause is
|
||||
// expected but not found.
|
||||
var ErrParseExpectedWhenClause = errors.New("Did not find the expected WHEN clause in the SQL expression. CASE is not supported")
|
||||
|
||||
// ErrParseUnsupportedToken is an error that occurs if the SQL expression
|
||||
// contains an unsupported token.
|
||||
var ErrParseUnsupportedToken = errors.New("The SQL expression contains an unsupported token")
|
||||
|
||||
// ErrParseUnsupportedLiteralsGroupBy is an error that occurs if the SQL
|
||||
// expression has an unsupported use of Group By.
|
||||
var ErrParseUnsupportedLiteralsGroupBy = errors.New("The SQL expression contains an unsupported use of GROUP BY")
|
||||
|
||||
// ErrParseExpectedMember is an error that occurs if there is an unsupported use
|
||||
// of member in the SQL expression.
|
||||
var ErrParseExpectedMember = errors.New("The SQL expression contains an unsupported use of MEMBER")
|
||||
|
||||
// ErrParseUnsupportedSelect is an error that occurs if there is an unsupported
|
||||
// use of Select.
|
||||
var ErrParseUnsupportedSelect = errors.New("The SQL expression contains an unsupported use of SELECT")
|
||||
|
||||
// ErrParseUnsupportedCase is an error that occurs if there is an unsupported
|
||||
// use of case.
|
||||
var ErrParseUnsupportedCase = errors.New("The SQL expression contains an unsupported use of CASE")
|
||||
|
||||
// ErrParseUnsupportedCaseClause is an error that occurs if there is an
|
||||
// unsupported use of case.
|
||||
var ErrParseUnsupportedCaseClause = errors.New("The SQL expression contains an unsupported use of CASE")
|
||||
|
||||
// ErrParseUnsupportedAlias is an error that occurs if there is an unsupported
|
||||
// use of Alias.
|
||||
var ErrParseUnsupportedAlias = errors.New("The SQL expression contains an unsupported use of ALIAS")
|
||||
|
||||
// ErrParseUnsupportedSyntax is an error that occurs if there is an
|
||||
// UnsupportedSyntax in the SQL expression.
|
||||
var ErrParseUnsupportedSyntax = errors.New("The SQL expression contains unsupported syntax")
|
||||
|
||||
// ErrParseUnknownOperator is an error that occurs if there is an invalid
|
||||
// operator present in the SQL expression.
|
||||
var ErrParseUnknownOperator = errors.New("The SQL expression contains an invalid operator")
|
||||
|
||||
// ErrParseInvalidPathComponent is an error that occurs if there is an invalid
|
||||
// path component.
|
||||
var ErrParseInvalidPathComponent = errors.New("The SQL expression contains an invalid path component")
|
||||
|
||||
// ErrParseMissingIdentAfterAt is an error that occurs if the wrong symbol
|
||||
// follows the "@" symbol in the SQL expression.
|
||||
var ErrParseMissingIdentAfterAt = errors.New("Did not find the expected identifier after the @ symbol in the SQL expression")
|
||||
|
||||
// ErrParseUnexpectedOperator is an error that occurs if the SQL expression
|
||||
// contains an unexpected operator.
|
||||
var ErrParseUnexpectedOperator = errors.New("The SQL expression contains an unexpected operator")
|
||||
|
||||
// ErrParseUnexpectedTerm is an error that occurs if the SQL expression contains
|
||||
// an unexpected term.
|
||||
var ErrParseUnexpectedTerm = errors.New("The SQL expression contains an unexpected term")
|
||||
|
||||
// ErrParseUnexpectedToken is an error that occurs if the SQL expression
|
||||
// contains an unexpected token.
|
||||
var ErrParseUnexpectedToken = errors.New("The SQL expression contains an unexpected token")
|
||||
|
||||
// ErrParseUnexpectedKeyword is an error that occurs if the SQL expression
|
||||
// contains an unexpected keyword.
|
||||
var ErrParseUnexpectedKeyword = errors.New("The SQL expression contains an unexpected keyword")
|
||||
|
||||
// ErrParseExpectedExpression is an error that occurs if the SQL expression is
|
||||
// not found.
|
||||
var ErrParseExpectedExpression = errors.New("Did not find the expected SQL expression")
|
||||
|
||||
// ErrParseExpectedLeftParenAfterCast is an error that occurs if the left
|
||||
// parenthesis is missing after a cast in the SQL expression.
|
||||
var ErrParseExpectedLeftParenAfterCast = errors.New("Did not find the expected left parenthesis after CAST in the SQL expression")
|
||||
|
||||
// ErrParseExpectedLeftParenValueConstructor is an error that occurs if the left
|
||||
// parenthesis is not found in the SQL expression.
|
||||
var ErrParseExpectedLeftParenValueConstructor = errors.New("Did not find expected the left parenthesis in the SQL expression")
|
||||
|
||||
// ErrParseExpectedLeftParenBuiltinFunctionCall is an error that occurs if the
|
||||
// left parenthesis is not found in the SQL expression function call.
|
||||
var ErrParseExpectedLeftParenBuiltinFunctionCall = errors.New("Did not find the expected left parenthesis in the SQL expression")
|
||||
|
||||
// ErrParseExpectedArgumentDelimiter is an error that occurs if the argument
|
||||
// delimiter for the SQL expression is not provided.
|
||||
var ErrParseExpectedArgumentDelimiter = errors.New("Did not find the expected argument delimiter in the SQL expression")
|
||||
|
||||
// ErrParseCastArity is an error that occurs because the CAST has incorrect
|
||||
// arity.
|
||||
var ErrParseCastArity = errors.New("The SQL expression CAST has incorrect arity")
|
||||
|
||||
// ErrParseInvalidTypeParam is an error that occurs because there is an invalid
|
||||
// parameter value.
|
||||
var ErrParseInvalidTypeParam = errors.New("The SQL expression contains an invalid parameter value")
|
||||
|
||||
// ErrParseEmptySelect is an error that occurs because the SQL expression
|
||||
// contains an empty Select
|
||||
var ErrParseEmptySelect = errors.New("The SQL expression contains an empty SELECT")
|
||||
|
||||
// ErrParseSelectMissingFrom is an error that occurs because there is a missing
|
||||
// From after the Select List.
|
||||
var ErrParseSelectMissingFrom = errors.New("The SQL expression contains a missing FROM after SELECT list")
|
||||
|
||||
// ErrParseExpectedIdentForGroupName is an error that occurs because Group is
|
||||
// not supported in the SQL expression.
|
||||
var ErrParseExpectedIdentForGroupName = errors.New("GROUP is not supported in the SQL expression")
|
||||
|
||||
// ErrParseExpectedIdentForAlias is an error that occurs if expected identifier
|
||||
// for alias is not in the SQL expression.
|
||||
var ErrParseExpectedIdentForAlias = errors.New("Did not find the expected identifier for the alias in the SQL expression")
|
||||
|
||||
// ErrParseUnsupportedCallWithStar is an error that occurs if COUNT is used with
|
||||
// an argument other than "*".
|
||||
var ErrParseUnsupportedCallWithStar = errors.New("Only COUNT with (*) as a parameter is supported in the SQL expression")
|
||||
|
||||
// ErrParseNonUnaryAgregateFunctionCall is an error that occurs if more than one
|
||||
// argument is provided as an argument for aggregation functions.
|
||||
var ErrParseNonUnaryAgregateFunctionCall = errors.New("Only one argument is supported for aggregate functions in the SQL expression")
|
||||
|
||||
// ErrParseMalformedJoin is an error that occurs if a "join" operation is
|
||||
// attempted in the SQL expression as this is not supported.
|
||||
var ErrParseMalformedJoin = errors.New("JOIN is not supported in the SQL expression")
|
||||
|
||||
// ErrParseExpectedIdentForAt is an error that occurs if after "AT" an Alias
|
||||
// identifier is not provided.
|
||||
var ErrParseExpectedIdentForAt = errors.New("Did not find the expected identifier for AT name in the SQL expression")
|
||||
|
||||
// ErrParseAsteriskIsNotAloneInSelectList is an error that occurs if in addition
|
||||
// to an asterisk, more column names are provided as arguments in the SQL
|
||||
// expression.
|
||||
var ErrParseAsteriskIsNotAloneInSelectList = errors.New("Other expressions are not allowed in the SELECT list when '*' is used without dot notation in the SQL expression")
|
||||
|
||||
// ErrParseCannotMixSqbAndWildcardInSelectList is an error that occurs if list
|
||||
// indexing and an asterisk are mixed in the SQL expression.
|
||||
var ErrParseCannotMixSqbAndWildcardInSelectList = errors.New("Cannot mix [] and * in the same expression in a SELECT list in SQL expression")
|
||||
|
||||
// ErrParseInvalidContextForWildcardInSelectList is an error that occurs if the
|
||||
// asterisk is used improperly within the SQL expression.
|
||||
var ErrParseInvalidContextForWildcardInSelectList = errors.New("Invalid use of * in SELECT list in the SQL expression")
|
||||
|
||||
// ErrEvaluatorBindingDoesNotExist is an error that occurs if a column name or
|
||||
// path provided in the expression does not exist.
|
||||
var ErrEvaluatorBindingDoesNotExist = errors.New("A column name or a path provided does not exist in the SQL expression")
|
||||
|
||||
// ErrIncorrectSQLFunctionArgumentType is an error that occurs if the wrong
|
||||
// argument is provided to a SQL function.
|
||||
var ErrIncorrectSQLFunctionArgumentType = errors.New("Incorrect type of arguments in function call in the SQL expression")
|
||||
|
||||
// ErrAmbiguousFieldName is an error that occurs if the column name which is not
|
||||
// case sensitive, is not descriptive enough to retrieve a singular column.
|
||||
var ErrAmbiguousFieldName = errors.New("Field name matches to multiple fields in the file. Check the SQL expression and the file, and try again")
|
||||
|
||||
// ErrEvaluatorInvalidArguments is an error that occurs if there are not the
|
||||
// correct number of arguments in a functional call to a SQL expression.
|
||||
var ErrEvaluatorInvalidArguments = errors.New("Incorrect number of arguments in the function call in the SQL expression")
|
||||
|
||||
// ErrValueParseFailure is an error that occurs if the Time Stamp is not parsed
|
||||
// correctly in the SQL expression.
|
||||
var ErrValueParseFailure = errors.New("Time stamp parse failure in the SQL expression")
|
||||
|
||||
// ErrIntegerOverflow is an error that occurs if there is an IntegerOverflow or
|
||||
// IntegerUnderFlow in the SQL expression.
|
||||
var ErrIntegerOverflow = errors.New("Int overflow or underflow in the SQL expression")
|
||||
|
||||
// ErrLikeInvalidInputs is an error that occurs if invalid inputs are provided
|
||||
// to the argument LIKE Clause.
|
||||
var ErrLikeInvalidInputs = errors.New("Invalid argument given to the LIKE clause in the SQL expression")
|
||||
|
||||
// ErrCastFailed occurs if the attempt to convert data types in the cast is not
|
||||
// done correctly.
|
||||
var ErrCastFailed = errors.New("Attempt to convert from one data type to another using CAST failed in the SQL expression")
|
||||
|
||||
// ErrInvalidCast is an error that occurs if the attempt to convert data types
|
||||
// failed and was done in an improper fashion.
|
||||
var ErrInvalidCast = errors.New("Attempt to convert from one data type to another using CAST failed in the SQL expression")
|
||||
|
||||
// ErrEvaluatorInvalidTimestampFormatPattern is an error that occurs if the Time
|
||||
// Stamp Format needs more additional fields to be filled.
|
||||
var ErrEvaluatorInvalidTimestampFormatPattern = errors.New("Time stamp format pattern requires additional fields in the SQL expression")
|
||||
|
||||
// ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing is an error that
|
||||
// occurs if the format of the time stamp can not be parsed.
|
||||
var ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing = errors.New("Time stamp format pattern contains a valid format symbol that cannot be applied to time stamp parsing in the SQL expression")
|
||||
|
||||
// ErrEvaluatorTimestampFormatPatternDuplicateFields is an error that occurs if
|
||||
// the time stamp format pattern contains multiple format specifications which
|
||||
// can not be clearly resolved.
|
||||
var ErrEvaluatorTimestampFormatPatternDuplicateFields = errors.New("Time stamp format pattern contains multiple format specifiers representing the time stamp field in the SQL expression")
|
||||
|
||||
// ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch is an error that
|
||||
// occurs if the time stamp format pattern contains a 12-hour hour-of-day format but
|
||||
// does not have an AM/PM field.
|
||||
var ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch = errors.New("Time stamp format pattern contains a 12-hour hour of day format symbol but doesn't also contain an AM/PM field, or it contains a 24-hour hour of day format specifier and contains an AM/PM field in the SQL expression")
|
||||
|
||||
// ErrEvaluatorUnterminatedTimestampFormatPatternToken is an error that occurs
|
||||
// if there is an unterminated token in the SQL expression for time stamp
|
||||
// format.
|
||||
var ErrEvaluatorUnterminatedTimestampFormatPatternToken = errors.New("Time stamp format pattern contains unterminated token in the SQL expression")
|
||||
|
||||
// ErrEvaluatorInvalidTimestampFormatPatternToken is an error that occurs if
|
||||
// there is an invalid token in the time stamp format within the SQL expression.
|
||||
var ErrEvaluatorInvalidTimestampFormatPatternToken = errors.New("Time stamp format pattern contains an invalid token in the SQL expression")
|
||||
|
||||
// ErrEvaluatorInvalidTimestampFormatPatternSymbol is an error that occurs if
|
||||
// the time stamp format pattern has an invalid symbol within the SQL
|
||||
// expression.
|
||||
var ErrEvaluatorInvalidTimestampFormatPatternSymbol = errors.New("Time stamp format pattern contains an invalid symbol in the SQL expression")
|
||||
|
||||
// S3 select API errors - TODO fix the errors.
|
||||
var errorCodeResponse = map[error]string{
|
||||
ErrBusy: "Busy",
|
||||
ErrUnauthorizedAccess: "UnauthorizedAccess",
|
||||
ErrExpressionTooLong: "ExpressionTooLong",
|
||||
ErrIllegalSQLFunctionArgument: "IllegalSqlFunctionArgument",
|
||||
ErrInvalidColumnIndex: "InvalidColumnIndex",
|
||||
ErrInvalidKeyPath: "InvalidKeyPath",
|
||||
ErrColumnTooLong: "ColumnTooLong",
|
||||
ErrOverMaxColumn: "OverMaxColumn",
|
||||
ErrOverMaxRecordSize: "OverMaxRecordSize",
|
||||
ErrMissingHeaders: "MissingHeaders",
|
||||
ErrInvalidCompressionFormat: "InvalidCompressionFormat",
|
||||
ErrTruncatedInput: "TruncatedInput",
|
||||
ErrInvalidFileHeaderInfo: "InvalidFileHeaderInfo",
|
||||
ErrInvalidJSONType: "InvalidJsonType",
|
||||
ErrInvalidQuoteFields: "InvalidQuoteFields",
|
||||
ErrInvalidRequestParameter: "InvalidRequestParameter",
|
||||
ErrCSVParsingError: "CSVParsingError",
|
||||
ErrJSONParsingError: "JSONParsingError",
|
||||
ErrExternalEvalException: "ExternalEvalException",
|
||||
ErrInvalidDataType: "InvalidDataType",
|
||||
ErrUnrecognizedFormatException: "UnrecognizedFormatException",
|
||||
ErrInvalidTextEncoding: "InvalidTextEncoding",
|
||||
ErrInvalidTableAlias: "InvalidTableAlias",
|
||||
ErrMultipleDataSourcesUnsupported: "MultipleDataSourcesUnsupported",
|
||||
ErrMissingRequiredParameter: "MissingRequiredParameter",
|
||||
ErrObjectSerializationConflict: "ObjectSerializationConflict",
|
||||
ErrUnsupportedSQLOperation: "UnsupportedSqlOperation",
|
||||
ErrUnsupportedSQLStructure: "UnsupportedSqlStructure",
|
||||
ErrUnsupportedStorageClass: "UnsupportedStorageClass",
|
||||
ErrUnsupportedSyntax: "UnsupportedSyntax",
|
||||
ErrUnsupportedRangeHeader: "UnsupportedRangeHeader",
|
||||
ErrLexerInvalidChar: "LexerInvalidChar",
|
||||
ErrLexerInvalidOperator: "LexerInvalidOperator",
|
||||
ErrLexerInvalidLiteral: "LexerInvalidLiteral",
|
||||
ErrLexerInvalidIONLiteral: "LexerInvalidIONLiteral",
|
||||
ErrParseExpectedDatePart: "ParseExpectedDatePart",
|
||||
ErrParseExpectedKeyword: "ParseExpectedKeyword",
|
||||
ErrParseExpectedTokenType: "ParseExpectedTokenType",
|
||||
ErrParseExpected2TokenTypes: "ParseExpected2TokenTypes",
|
||||
ErrParseExpectedNumber: "ParseExpectedNumber",
|
||||
ErrParseExpectedRightParenBuiltinFunctionCall: "ParseExpectedRightParenBuiltinFunctionCall",
|
||||
ErrParseExpectedTypeName: "ParseExpectedTypeName",
|
||||
ErrParseExpectedWhenClause: "ParseExpectedWhenClause",
|
||||
ErrParseUnsupportedToken: "ParseUnsupportedToken",
|
||||
ErrParseUnsupportedLiteralsGroupBy: "ParseUnsupportedLiteralsGroupBy",
|
||||
ErrParseExpectedMember: "ParseExpectedMember",
|
||||
ErrParseUnsupportedSelect: "ParseUnsupportedSelect",
|
||||
ErrParseUnsupportedCase: "ParseUnsupportedCase:",
|
||||
ErrParseUnsupportedCaseClause: "ParseUnsupportedCaseClause",
|
||||
ErrParseUnsupportedAlias: "ParseUnsupportedAlias",
|
||||
ErrParseUnsupportedSyntax: "ParseUnsupportedSyntax",
|
||||
ErrParseUnknownOperator: "ParseUnknownOperator",
|
||||
ErrParseInvalidPathComponent: "ParseInvalidPathComponent",
|
||||
ErrParseMissingIdentAfterAt: "ParseMissingIdentAfterAt",
|
||||
ErrParseUnexpectedOperator: "ParseUnexpectedOperator",
|
||||
ErrParseUnexpectedTerm: "ParseUnexpectedTerm",
|
||||
ErrParseUnexpectedToken: "ParseUnexpectedToken",
|
||||
ErrParseUnexpectedKeyword: "ParseUnexpectedKeyword",
|
||||
ErrParseExpectedExpression: "ParseExpectedExpression",
|
||||
ErrParseExpectedLeftParenAfterCast: "ParseExpectedLeftParenAfterCast",
|
||||
ErrParseExpectedLeftParenValueConstructor: "ParseExpectedLeftParenValueConstructor",
|
||||
ErrParseExpectedLeftParenBuiltinFunctionCall: "ParseExpectedLeftParenBuiltinFunctionCall",
|
||||
ErrParseExpectedArgumentDelimiter: "ParseExpectedArgumentDelimiter",
|
||||
ErrParseCastArity: "ParseCastArity",
|
||||
ErrParseInvalidTypeParam: "ParseInvalidTypeParam",
|
||||
ErrParseEmptySelect: "ParseEmptySelect",
|
||||
ErrParseSelectMissingFrom: "ParseSelectMissingFrom",
|
||||
ErrParseExpectedIdentForGroupName: "ParseExpectedIdentForGroupName",
|
||||
ErrParseExpectedIdentForAlias: "ParseExpectedIdentForAlias",
|
||||
ErrParseUnsupportedCallWithStar: "ParseUnsupportedCallWithStar",
|
||||
ErrParseNonUnaryAgregateFunctionCall: "ParseNonUnaryAgregateFunctionCall",
|
||||
ErrParseMalformedJoin: "ParseMalformedJoin",
|
||||
ErrParseExpectedIdentForAt: "ParseExpectedIdentForAt",
|
||||
ErrParseAsteriskIsNotAloneInSelectList: "ParseAsteriskIsNotAloneInSelectList",
|
||||
ErrParseCannotMixSqbAndWildcardInSelectList: "ParseCannotMixSqbAndWildcardInSelectList",
|
||||
ErrParseInvalidContextForWildcardInSelectList: "ParseInvalidContextForWildcardInSelectList",
|
||||
ErrEvaluatorBindingDoesNotExist: "EvaluatorBindingDoesNotExist",
|
||||
ErrIncorrectSQLFunctionArgumentType: "IncorrectSqlFunctionArgumentType",
|
||||
ErrAmbiguousFieldName: "AmbiguousFieldName",
|
||||
ErrEvaluatorInvalidArguments: "EvaluatorInvalidArguments",
|
||||
ErrValueParseFailure: "ValueParseFailure",
|
||||
ErrIntegerOverflow: "IntegerOverflow",
|
||||
ErrLikeInvalidInputs: "LikeInvalidInputs",
|
||||
ErrCastFailed: "CastFailed",
|
||||
ErrInvalidCast: "Attempt to convert from one data type to another using CAST failed in the SQL expression.",
|
||||
ErrEvaluatorInvalidTimestampFormatPattern: "EvaluatorInvalidTimestampFormatPattern",
|
||||
ErrEvaluatorInvalidTimestampFormatPatternSymbolForParsing: "EvaluatorInvalidTimestampFormatPatternSymbolForParsing",
|
||||
ErrEvaluatorTimestampFormatPatternDuplicateFields: "EvaluatorTimestampFormatPatternDuplicateFields",
|
||||
ErrEvaluatorTimestampFormatPatternHourClockAmPmMismatch: "EvaluatorTimestampFormatPatternHourClockAmPmMismatch",
|
||||
ErrEvaluatorUnterminatedTimestampFormatPatternToken: "EvaluatorUnterminatedTimestampFormatPatternToken",
|
||||
ErrEvaluatorInvalidTimestampFormatPatternToken: "EvaluatorInvalidTimestampFormatPatternToken",
|
||||
ErrEvaluatorInvalidTimestampFormatPatternSymbol: "EvaluatorInvalidTimestampFormatPatternSymbol",
|
||||
}
|
|
@ -0,0 +1,231 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2018 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package s3select
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/xwb1989/sqlparser"
|
||||
)
|
||||
|
||||
// stringOps is a function which handles the case in a clause if there is a need
|
||||
// to perform a string function
|
||||
func stringOps(myFunc *sqlparser.FuncExpr, record []string, myReturnVal string, columnsMap map[string]int) string {
|
||||
var value string
|
||||
funcName := myFunc.Name.CompliantName()
|
||||
switch tempArg := myFunc.Exprs[0].(type) {
|
||||
case *sqlparser.AliasedExpr:
|
||||
switch col := tempArg.Expr.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
// myReturnVal is actually the tail recursive value being used in the eval func.
|
||||
return applyStrFunc(myReturnVal, funcName)
|
||||
case *sqlparser.ColName:
|
||||
value = applyStrFunc(record[columnsMap[col.Name.CompliantName()]], funcName)
|
||||
case *sqlparser.SQLVal:
|
||||
value = applyStrFunc(string(col.Val), funcName)
|
||||
}
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// coalOps is a function which decomposes a COALESCE func expr into its struct.
|
||||
func coalOps(myFunc *sqlparser.FuncExpr, record []string, myReturnVal string, columnsMap map[string]int) string {
|
||||
myArgs := make([]string, len(myFunc.Exprs))
|
||||
|
||||
for i := 0; i < len(myFunc.Exprs); i++ {
|
||||
switch tempArg := myFunc.Exprs[i].(type) {
|
||||
case *sqlparser.AliasedExpr:
|
||||
switch col := tempArg.Expr.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
// myReturnVal is actually the tail recursive value being used in the eval func.
|
||||
return myReturnVal
|
||||
case *sqlparser.ColName:
|
||||
myArgs[i] = record[columnsMap[col.Name.CompliantName()]]
|
||||
case *sqlparser.SQLVal:
|
||||
myArgs[i] = string(col.Val)
|
||||
}
|
||||
}
|
||||
}
|
||||
return processCoalNoIndex(myArgs)
|
||||
}
|
||||
|
||||
// nullOps is a function which decomposes a NullIf func expr into its struct.
|
||||
func nullOps(myFunc *sqlparser.FuncExpr, record []string, myReturnVal string, columnsMap map[string]int) string {
|
||||
myArgs := make([]string, 2)
|
||||
|
||||
for i := 0; i < len(myFunc.Exprs); i++ {
|
||||
switch tempArg := myFunc.Exprs[i].(type) {
|
||||
case *sqlparser.AliasedExpr:
|
||||
switch col := tempArg.Expr.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
return myReturnVal
|
||||
case *sqlparser.ColName:
|
||||
myArgs[i] = record[columnsMap[col.Name.CompliantName()]]
|
||||
case *sqlparser.SQLVal:
|
||||
myArgs[i] = string(col.Val)
|
||||
}
|
||||
}
|
||||
}
|
||||
return processNullIf(myArgs)
|
||||
}
|
||||
|
||||
// isValidFunc reports whether index is one of the values stored in myList. A
// nil or empty list never contains the index.
func isValidFunc(myList []int, index int) bool {
	for _, candidate := range myList {
		if candidate == index {
			return true
		}
	}
	return false
}
|
||||
|
||||
// processNullIf evaluates NULLIF over the first two entries of nullStore: it
// returns the empty string when both entries are equal and the first entry
// otherwise. nullStore must hold at least two entries.
func processNullIf(nullStore []string) string {
	first, second := nullStore[0], nullStore[1]
	if first != second {
		return first
	}
	return ""
}
|
||||
|
||||
// processCoalNoIndex evaluates COALESCE: it returns the first entry of
// coalStore that is neither "null", "missing" nor the empty string, and "null"
// when no such entry exists.
func processCoalNoIndex(coalStore []string) string {
	for _, candidate := range coalStore {
		switch candidate {
		case "null", "missing", "":
			// Treated as absent; keep looking.
			continue
		}
		return candidate
	}
	return "null"
}
|
||||
|
||||
// evaluateFuncExpr is a function that allows for tail recursive evaluation of
|
||||
// nested function expressions.
|
||||
func evaluateFuncExpr(myVal *sqlparser.FuncExpr, myReturnVal string, myRecord []string, columnsMap map[string]int) string {
|
||||
if myVal == nil {
|
||||
return myReturnVal
|
||||
}
|
||||
// retrieve all the relevant arguments of the function
|
||||
var mySubFunc []*sqlparser.FuncExpr
|
||||
mySubFunc = make([]*sqlparser.FuncExpr, len(myVal.Exprs))
|
||||
for i := 0; i < len(myVal.Exprs); i++ {
|
||||
switch col := myVal.Exprs[i].(type) {
|
||||
case *sqlparser.AliasedExpr:
|
||||
switch temp := col.Expr.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
mySubFunc[i] = temp
|
||||
}
|
||||
}
|
||||
}
|
||||
// Need to do tree recursion so as to explore all possible directions of the
|
||||
// nested function recursion
|
||||
for i := 0; i < len(mySubFunc); i++ {
|
||||
if supportedString(myVal.Name.CompliantName()) {
|
||||
if mySubFunc != nil {
|
||||
return stringOps(myVal, myRecord, evaluateFuncExpr(mySubFunc[i], myReturnVal, myRecord, columnsMap), columnsMap)
|
||||
}
|
||||
return stringOps(myVal, myRecord, myReturnVal, columnsMap)
|
||||
} else if strings.ToUpper(myVal.Name.CompliantName()) == "NULLIF" {
|
||||
if mySubFunc != nil {
|
||||
return nullOps(myVal, myRecord, evaluateFuncExpr(mySubFunc[i], myReturnVal, myRecord, columnsMap), columnsMap)
|
||||
}
|
||||
return nullOps(myVal, myRecord, myReturnVal, columnsMap)
|
||||
} else if strings.ToUpper(myVal.Name.CompliantName()) == "COALESCE" {
|
||||
if mySubFunc != nil {
|
||||
return coalOps(myVal, myRecord, evaluateFuncExpr(mySubFunc[i], myReturnVal, myRecord, columnsMap), columnsMap)
|
||||
}
|
||||
return coalOps(myVal, myRecord, myReturnVal, columnsMap)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// evaluateFuncErr is a function that flags errors in nested functions.
|
||||
func (reader *Input) evaluateFuncErr(myVal *sqlparser.FuncExpr) error {
|
||||
if myVal == nil {
|
||||
return nil
|
||||
}
|
||||
if !supportedFunc(myVal.Name.CompliantName()) {
|
||||
return ErrUnsupportedSQLOperation
|
||||
}
|
||||
for i := 0; i < len(myVal.Exprs); i++ {
|
||||
switch tempArg := myVal.Exprs[i].(type) {
|
||||
case *sqlparser.StarExpr:
|
||||
return ErrParseUnsupportedCallWithStar
|
||||
case *sqlparser.AliasedExpr:
|
||||
switch col := tempArg.Expr.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
if err := reader.evaluateFuncErr(col); err != nil {
|
||||
return err
|
||||
}
|
||||
case *sqlparser.ColName:
|
||||
if err := reader.colNameErrs([]string{col.Name.CompliantName()}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// evaluateIsExpr is a function for evaluating expressions of the form "column
|
||||
// is ...."
|
||||
func evaluateIsExpr(myFunc *sqlparser.IsExpr, row []string, columnNames map[string]int, alias string) (bool, error) {
|
||||
operator := myFunc.Operator
|
||||
var colName string
|
||||
var myVal string
|
||||
switch myIs := myFunc.Expr.(type) {
|
||||
// case for literal val
|
||||
case *sqlparser.SQLVal:
|
||||
myVal = string(myIs.Val)
|
||||
// case for nested func val
|
||||
case *sqlparser.FuncExpr:
|
||||
myVal = evaluateFuncExpr(myIs, "", row, columnNames)
|
||||
// case for col val
|
||||
case *sqlparser.ColName:
|
||||
colName = cleanCol(myIs.Name.CompliantName(), alias)
|
||||
}
|
||||
// case if it is a col val
|
||||
if colName != "" {
|
||||
myVal = row[columnNames[colName]]
|
||||
}
|
||||
// case to evaluate is null
|
||||
if strings.ToLower(operator) == "is null" {
|
||||
return myVal == "", nil
|
||||
}
|
||||
// case to evaluate is not null
|
||||
if strings.ToLower(operator) == "is not null" {
|
||||
return myVal != "", nil
|
||||
}
|
||||
return false, ErrUnsupportedSQLOperation
|
||||
}
|
||||
|
||||
// supportedString reports whether strFunc names one of the supported string
// functions. The comparison is case-insensitive.
func supportedString(strFunc string) bool {
	// A switch over constants avoids building a slice on every call.
	switch strings.ToUpper(strFunc) {
	case "TRIM", "SUBSTRING", "CHAR_LENGTH", "CHARACTER_LENGTH", "LOWER", "UPPER":
		return true
	}
	return false
}
|
||||
|
||||
// supportedFunc reports whether strFunc names a function supported in S3 Select
// expressions: the string functions plus COALESCE and NULLIF. The comparison is
// case-insensitive.
func supportedFunc(strFunc string) bool {
	// A switch over constants avoids building a slice on every call.
	switch strings.ToUpper(strFunc) {
	case "TRIM", "SUBSTRING", "CHAR_LENGTH", "CHARACTER_LENGTH", "LOWER", "UPPER", "COALESCE", "NULLIF":
		return true
	}
	return false
}
|
|
@ -0,0 +1,754 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2018 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package s3select
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/xwb1989/sqlparser"
|
||||
)
|
||||
|
||||
// processSize returns the size attributed to a record when accounting for
// processed bytes: the byte length of every field plus one extra byte per
// field (presumably the delimiters between fields and the record terminator).
// An empty record has size 0.
func processSize(myrecord []string) int64 {
	if len(myrecord) == 0 {
		return 0
	}
	total := int64(len(myrecord))
	for _, field := range myrecord {
		total += int64(len(field))
	}
	return total
}
|
||||
|
||||
// stringInSlice reports whether x is an element of list.
func stringInSlice(x string, list []string) bool {
	for i := 0; i < len(list); i++ {
		if list[i] == x {
			return true
		}
	}
	return false
}
|
||||
|
||||
// stringIndex returns the position of the first element of list equal to a, or
// -1 when list does not contain a.
func stringIndex(a string, list []string) int {
	for pos, item := range list {
		if item == a {
			return pos
		}
	}
	return -1
}
|
||||
|
||||
// representsInt reports whether s can be parsed as an int by strconv.Atoi.
func representsInt(s string) bool {
	_, err := strconv.Atoi(s)
	return err == nil
}
|
||||
|
||||
// The function below processes the where clause into an acutal boolean given a
|
||||
// row
|
||||
func matchesMyWhereClause(row []string, columnNames map[string]int, alias string, whereClause interface{}) (bool, error) {
|
||||
// This particular logic deals with the details of casting, e.g if we have to
|
||||
// cast a column of string numbers into int's for comparison.
|
||||
var conversionColumn string
|
||||
var operator string
|
||||
var operand interface{}
|
||||
if fmt.Sprintf("%v", whereClause) == "false" {
|
||||
return false, nil
|
||||
}
|
||||
switch expr := whereClause.(type) {
|
||||
case *sqlparser.IsExpr:
|
||||
return evaluateIsExpr(expr, row, columnNames, alias)
|
||||
case *sqlparser.RangeCond:
|
||||
operator = expr.Operator
|
||||
if operator != "between" && operator != "not between" {
|
||||
return false, ErrUnsupportedSQLOperation
|
||||
}
|
||||
if operator == "not between" {
|
||||
myResult, err := evaluateBetween(expr, alias, row, columnNames)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return !myResult, nil
|
||||
}
|
||||
myResult, err := evaluateBetween(expr, alias, row, columnNames)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return myResult, nil
|
||||
case *sqlparser.ComparisonExpr:
|
||||
operator = expr.Operator
|
||||
switch right := expr.Right.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
operand = evaluateFuncExpr(right, "", row, columnNames)
|
||||
case *sqlparser.SQLVal:
|
||||
var err error
|
||||
operand, err = evaluateParserType(right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
var myVal string
|
||||
myVal = ""
|
||||
switch left := expr.Left.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
myVal = evaluateFuncExpr(left, "", row, columnNames)
|
||||
conversionColumn = ""
|
||||
case *sqlparser.ColName:
|
||||
conversionColumn = cleanCol(left.Name.CompliantName(), alias)
|
||||
}
|
||||
if representsInt(conversionColumn) {
|
||||
intCol, err := strconv.Atoi(conversionColumn)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// Subtract 1 out because the index starts at 1 for Amazon instead of 0.
|
||||
return evaluateOperator(row[intCol-1], operator, operand)
|
||||
}
|
||||
if myVal != "" {
|
||||
return evaluateOperator(myVal, operator, operand)
|
||||
}
|
||||
return evaluateOperator(row[columnNames[conversionColumn]], operator, operand)
|
||||
case *sqlparser.AndExpr:
|
||||
var leftVal bool
|
||||
var rightVal bool
|
||||
switch left := expr.Left.(type) {
|
||||
case *sqlparser.ComparisonExpr:
|
||||
temp, err := matchesMyWhereClause(row, columnNames, alias, left)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
leftVal = temp
|
||||
}
|
||||
switch right := expr.Right.(type) {
|
||||
case *sqlparser.ComparisonExpr:
|
||||
temp, err := matchesMyWhereClause(row, columnNames, alias, right)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
rightVal = temp
|
||||
}
|
||||
return (rightVal && leftVal), nil
|
||||
case *sqlparser.OrExpr:
|
||||
var leftVal bool
|
||||
var rightVal bool
|
||||
switch left := expr.Left.(type) {
|
||||
case *sqlparser.ComparisonExpr:
|
||||
leftVal, _ = matchesMyWhereClause(row, columnNames, alias, left)
|
||||
|
||||
}
|
||||
switch right := expr.Right.(type) {
|
||||
case *sqlparser.ComparisonExpr:
|
||||
rightVal, _ = matchesMyWhereClause(row, columnNames, alias, right)
|
||||
}
|
||||
return (rightVal || leftVal), nil
|
||||
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
// applyStrFunc applies the SQL string function funcName (case-insensitive) to
// rawArg and returns the result:
//   - TRIM removes leading and trailing spaces,
//   - CHAR_LENGTH / CHARACTER_LENGTH return the number of characters,
//   - LOWER / UPPER change the case,
//   - SUBSTRING and any unknown name return rawArg unchanged.
func applyStrFunc(rawArg string, funcName string) string {
	switch strings.ToUpper(funcName) {
	case "TRIM":
		// parser has an issue which does not allow it to support Trim with
		// other arguments
		return strings.Trim(rawArg, " ")
	case "CHAR_LENGTH", "CHARACTER_LENGTH":
		// Count characters rather than bytes so multi-byte UTF-8 text is
		// measured correctly.
		return strconv.Itoa(len([]rune(rawArg)))
	case "LOWER":
		return strings.ToLower(rawArg)
	case "UPPER":
		return strings.ToUpper(rawArg)
	}
	// TODO parser has an issue which does not support substring, so SUBSTRING
	// is passed through like an unknown function.
	return rawArg
}
|
||||
|
||||
// This is a really important function it actually evaluates the boolean
|
||||
// statement and therefore actually returns a bool, it functions as the lowest
|
||||
// level of the state machine.
|
||||
func evaluateOperator(myTblVal string, operator string, operand interface{}) (bool, error) {
|
||||
if err := checkValidOperator(operator); err != nil {
|
||||
return false, err
|
||||
}
|
||||
myRecordVal := checkStringType(myTblVal)
|
||||
myVal := reflect.ValueOf(myRecordVal)
|
||||
myOp := reflect.ValueOf(operand)
|
||||
|
||||
switch {
|
||||
case myVal.Kind() == reflect.String && myOp.Kind() == reflect.String:
|
||||
return stringEval(myVal.String(), operator, myOp.String())
|
||||
case myVal.Kind() == reflect.Float64 && myOp.Kind() == reflect.Float64:
|
||||
return floatEval(myVal.Float(), operator, myOp.Float())
|
||||
case myVal.Kind() == reflect.Int && myOp.Kind() == reflect.Int:
|
||||
return intEval(myVal.Int(), operator, myOp.Int())
|
||||
case myVal.Kind() == reflect.Int && myOp.Kind() == reflect.String:
|
||||
stringVs := strconv.Itoa(int(myVal.Int()))
|
||||
return stringEval(stringVs, operator, myOp.String())
|
||||
case myVal.Kind() == reflect.Float64 && myOp.Kind() == reflect.String:
|
||||
stringVs := strconv.FormatFloat(myVal.Float(), 'f', 6, 64)
|
||||
return stringEval(stringVs, operator, myOp.String())
|
||||
case myVal.Kind() != myOp.Kind():
|
||||
return false, nil
|
||||
}
|
||||
return false, ErrUnsupportedSyntax
|
||||
}
|
||||
|
||||
// checkValidOperator ensures that the current operator is supported
|
||||
func checkValidOperator(operator string) error {
|
||||
listOfOps := []string{">", "<", "=", "<=", ">=", "!=", "like"}
|
||||
for i := range listOfOps {
|
||||
if operator == listOfOps[i] {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrParseUnknownOperator
|
||||
}
|
||||
|
||||
// checkStringType converts a raw csv value to the most specific type it can be
// parsed as: an int if strconv.Atoi accepts it, otherwise a float64 if
// strconv.ParseFloat accepts it, otherwise the unchanged string.
func checkStringType(myTblVal string) interface{} {
	if asInt, err := strconv.Atoi(myTblVal); err == nil {
		return asInt
	}
	if asFloat, err := strconv.ParseFloat(myTblVal, 64); err == nil {
		return asFloat
	}
	return myTblVal
}
|
||||
|
||||
// stringEval is for evaluating the state of string comparison.
|
||||
func stringEval(myRecordVal string, operator string, myOperand string) (bool, error) {
|
||||
switch operator {
|
||||
case ">":
|
||||
return myRecordVal > myOperand, nil
|
||||
case "<":
|
||||
return myRecordVal < myOperand, nil
|
||||
case "=":
|
||||
return myRecordVal == myOperand, nil
|
||||
case "<=":
|
||||
return myRecordVal <= myOperand, nil
|
||||
case ">=":
|
||||
return myRecordVal >= myOperand, nil
|
||||
case "!=":
|
||||
return myRecordVal != myOperand, nil
|
||||
case "like":
|
||||
return likeConvert(myOperand, myRecordVal)
|
||||
}
|
||||
return false, ErrUnsupportedSyntax
|
||||
}
|
||||
|
||||
// intEval is for evaluating integer comparisons.
|
||||
func intEval(myRecordVal int64, operator string, myOperand int64) (bool, error) {
|
||||
|
||||
switch operator {
|
||||
case ">":
|
||||
return myRecordVal > myOperand, nil
|
||||
case "<":
|
||||
return myRecordVal < myOperand, nil
|
||||
case "=":
|
||||
return myRecordVal == myOperand, nil
|
||||
case "<=":
|
||||
return myRecordVal <= myOperand, nil
|
||||
case ">=":
|
||||
return myRecordVal >= myOperand, nil
|
||||
case "!=":
|
||||
return myRecordVal != myOperand, nil
|
||||
}
|
||||
return false, ErrUnsupportedSyntax
|
||||
}
|
||||
|
||||
// floatEval is for evaluating the comparison of floats.
|
||||
func floatEval(myRecordVal float64, operator string, myOperand float64) (bool, error) {
|
||||
// Basically need some logic thats like, if the types dont match check for a cast
|
||||
switch operator {
|
||||
case ">":
|
||||
return myRecordVal > myOperand, nil
|
||||
case "<":
|
||||
return myRecordVal < myOperand, nil
|
||||
case "=":
|
||||
return myRecordVal == myOperand, nil
|
||||
case "<=":
|
||||
return myRecordVal <= myOperand, nil
|
||||
case ">=":
|
||||
return myRecordVal >= myOperand, nil
|
||||
case "!=":
|
||||
return myRecordVal != myOperand, nil
|
||||
}
|
||||
return false, ErrUnsupportedSyntax
|
||||
}
|
||||
|
||||
// prefixMatch evaluates a prefix-only LIKE pattern such as "a%": every pattern
// byte except the last one (the trailing wildcard) must equal the record byte
// at the same position, where '_' in the pattern accepts any byte. A record
// shorter than the prefix never matches.
func prefixMatch(pattern string, record string) bool {
	prefixLen := len(pattern) - 1
	// Without this guard a short record would be indexed out of range.
	if len(record) < prefixLen {
		return false
	}
	for i := 0; i < prefixLen; i++ {
		if pattern[i] != '_' && pattern[i] != record[i] {
			return false
		}
	}
	return true
}
|
||||
|
||||
// suffixMatch allows for matching a suffix only like query e.g %an
|
||||
func suffixMatch(pattern string, record string) bool {
|
||||
for i := len(pattern) - 1; i > 0; i-- {
|
||||
if pattern[i] != record[len(record)-(len(pattern)-i)] && pattern[i] != byte('_') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// likeConvert evaluates the case-sensitive SQL LIKE operator: it reports
// whether record matches pattern, where '%' in the pattern matches any
// sequence of characters (including none) and '_' matches exactly one
// character. Every other pattern character, '*' included, must match
// literally, and the whole record has to be consumed by the pattern.
//
// An empty pattern or an empty record never matches. The returned error is
// always nil; it is part of the signature for the benefit of callers.
func likeConvert(pattern string, record string) (bool, error) {
	if pattern == "" || record == "" {
		return false, nil
	}
	// Work on runes so that '_' consumes one character, not one byte.
	pat, rec := []rune(pattern), []rune(record)

	p, r := 0, 0
	// lastWild is the position of the most recent '%' in the pattern and
	// wildStart the record position it was first tried at; together they
	// allow backtracking when a later literal fails to match.
	lastWild, wildStart := -1, 0
	for r < len(rec) {
		switch {
		case p < len(pat) && pat[p] == '%':
			lastWild, wildStart = p, r
			p++
		case p < len(pat) && (pat[p] == '_' || pat[p] == rec[r]):
			p++
			r++
		case lastWild >= 0:
			// Let the last '%' swallow one more character and retry.
			wildStart++
			p, r = lastWild+1, wildStart
		default:
			return false, nil
		}
	}
	// Only trailing wildcards may remain once the record is exhausted.
	for p < len(pat) && pat[p] == '%' {
		p++
	}
	return p == len(pat), nil
}
|
||||
|
||||
// trimQuotes removes one pair of enclosing double quotes from s, so that a
// column written as "name" can be looked up in the map of column names. Any
// string that is not both opened and closed by a double quote is returned
// unchanged.
func trimQuotes(s string) string {
	last := len(s) - 1
	if last < 1 || s[0] != '"' || s[last] != '"' {
		return s
	}
	return s[1:last]
}
|
||||
|
||||
// cleanCol normalises a column name handed back by the parser. A name that
// starts with '_' and is not already prefixed by alias is first prefixed with
// it; afterwards every occurrence of alias+"._" (only looked for when the name
// contains a '.') and of alias+"_" is removed. An empty name is returned as is.
func cleanCol(myCol string, alias string) string {
	if myCol == "" {
		return myCol
	}
	name := myCol
	if name[0] == '_' && !strings.HasPrefix(name, alias) {
		name = alias + name
	}
	if strings.Contains(name, ".") {
		name = strings.Replace(name, alias+"._", "", -1)
	}
	return strings.Replace(name, alias+"_", "", -1)
}
|
||||
|
||||
// evaluateBetween is a function which evaluates a Between Clause.
|
||||
func evaluateBetween(betweenExpr *sqlparser.RangeCond, alias string, record []string, columnNames map[string]int) (bool, error) {
|
||||
var colToVal interface{}
|
||||
var colFromVal interface{}
|
||||
var conversionColumn string
|
||||
var funcName string
|
||||
switch colTo := betweenExpr.To.(type) {
|
||||
case sqlparser.Expr:
|
||||
switch colToMyVal := colTo.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
var temp string
|
||||
temp = stringOps(colToMyVal, record, "", columnNames)
|
||||
colToVal = []byte(temp)
|
||||
case *sqlparser.SQLVal:
|
||||
var err error
|
||||
colToVal, err = evaluateParserType(colToMyVal)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
}
|
||||
switch colFrom := betweenExpr.From.(type) {
|
||||
case sqlparser.Expr:
|
||||
switch colFromMyVal := colFrom.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
colFromVal = stringOps(colFromMyVal, record, "", columnNames)
|
||||
case *sqlparser.SQLVal:
|
||||
var err error
|
||||
colFromVal, err = evaluateParserType(colFromMyVal)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
}
|
||||
var myFuncVal string
|
||||
myFuncVal = ""
|
||||
switch left := betweenExpr.Left.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
myFuncVal = evaluateFuncExpr(left, "", record, columnNames)
|
||||
conversionColumn = ""
|
||||
case *sqlparser.ColName:
|
||||
conversionColumn = cleanCol(left.Name.CompliantName(), alias)
|
||||
}
|
||||
|
||||
toGreater, err := evaluateOperator(fmt.Sprintf("%v", colToVal), ">", colFromVal)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if toGreater {
|
||||
return evalBetweenGreater(conversionColumn, record, funcName, columnNames, colFromVal, colToVal, myFuncVal)
|
||||
}
|
||||
return evalBetweenLess(conversionColumn, record, funcName, columnNames, colFromVal, colToVal, myFuncVal)
|
||||
}
|
||||
|
||||
// evalBetweenLess is a function which evaluates the between given that the
|
||||
// FROM is > than the TO.
|
||||
func evalBetweenLess(conversionColumn string, record []string, funcName string, columnNames map[string]int, colFromVal interface{}, colToVal interface{}, myCoalVal string) (bool, error) {
|
||||
if representsInt(conversionColumn) {
|
||||
myIndex, _ := strconv.Atoi(conversionColumn)
|
||||
// Subtract 1 out because the index starts at 1 for Amazon instead of 0.
|
||||
myVal, err := evaluateOperator(record[myIndex-1], "<=", colFromVal)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var myOtherVal bool
|
||||
myOtherVal, err = evaluateOperator(fmt.Sprintf("%v", colToVal), "<=", checkStringType(record[myIndex-1]))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return (myVal && myOtherVal), nil
|
||||
}
|
||||
if myCoalVal != "" {
|
||||
myVal, err := evaluateOperator(myCoalVal, "<=", colFromVal)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var myOtherVal bool
|
||||
myOtherVal, err = evaluateOperator(fmt.Sprintf("%v", colToVal), "<=", checkStringType(myCoalVal))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return (myVal && myOtherVal), nil
|
||||
}
|
||||
myVal, err := evaluateOperator(record[columnNames[conversionColumn]], "<=", colFromVal)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var myOtherVal bool
|
||||
myOtherVal, err = evaluateOperator(fmt.Sprintf("%v", colToVal), "<=", checkStringType(record[columnNames[conversionColumn]]))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return (myVal && myOtherVal), nil
|
||||
}
|
||||
|
||||
// evalBetweenGreater is a function which evaluates the between given that the
|
||||
// TO is > than the FROM.
|
||||
func evalBetweenGreater(conversionColumn string, record []string, funcName string, columnNames map[string]int, colFromVal interface{}, colToVal interface{}, myCoalVal string) (bool, error) {
|
||||
if representsInt(conversionColumn) {
|
||||
myIndex, _ := strconv.Atoi(conversionColumn)
|
||||
myVal, err := evaluateOperator(record[myIndex-1], ">=", colFromVal)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var myOtherVal bool
|
||||
myOtherVal, err = evaluateOperator(fmt.Sprintf("%v", colToVal), ">=", checkStringType(record[myIndex-1]))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return (myVal && myOtherVal), nil
|
||||
}
|
||||
if myCoalVal != "" {
|
||||
myVal, err := evaluateOperator(myCoalVal, ">=", colFromVal)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var myOtherVal bool
|
||||
myOtherVal, err = evaluateOperator(fmt.Sprintf("%v", colToVal), ">=", checkStringType(myCoalVal))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return (myVal && myOtherVal), nil
|
||||
}
|
||||
myVal, err := evaluateOperator(record[columnNames[conversionColumn]], ">=", colFromVal)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var myOtherVal bool
|
||||
myOtherVal, err = evaluateOperator(fmt.Sprintf("%v", colToVal), ">=", checkStringType(record[columnNames[conversionColumn]]))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return (myVal && myOtherVal), nil
|
||||
}
|
||||
|
||||
// whereClauseNameErrs is a function which returns an error if there is a column
|
||||
// in the where clause which does not exist.
|
||||
func (reader *Input) whereClauseNameErrs(whereClause interface{}, alias string) error {
|
||||
var conversionColumn string
|
||||
switch expr := whereClause.(type) {
|
||||
// case for checking errors within a clause of the form "col_name is ..."
|
||||
case *sqlparser.IsExpr:
|
||||
switch myCol := expr.Expr.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
if err := reader.evaluateFuncErr(myCol); err != nil {
|
||||
return err
|
||||
}
|
||||
case *sqlparser.ColName:
|
||||
conversionColumn = cleanCol(myCol.Name.CompliantName(), alias)
|
||||
}
|
||||
case *sqlparser.RangeCond:
|
||||
switch left := expr.Left.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
if err := reader.evaluateFuncErr(left); err != nil {
|
||||
return err
|
||||
}
|
||||
case *sqlparser.ColName:
|
||||
conversionColumn = cleanCol(left.Name.CompliantName(), alias)
|
||||
}
|
||||
case *sqlparser.ComparisonExpr:
|
||||
switch left := expr.Left.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
if err := reader.evaluateFuncErr(left); err != nil {
|
||||
return err
|
||||
}
|
||||
case *sqlparser.ColName:
|
||||
conversionColumn = cleanCol(left.Name.CompliantName(), alias)
|
||||
}
|
||||
case *sqlparser.AndExpr:
|
||||
switch left := expr.Left.(type) {
|
||||
case *sqlparser.ComparisonExpr:
|
||||
return reader.whereClauseNameErrs(left, alias)
|
||||
}
|
||||
switch right := expr.Right.(type) {
|
||||
case *sqlparser.ComparisonExpr:
|
||||
return reader.whereClauseNameErrs(right, alias)
|
||||
}
|
||||
case *sqlparser.OrExpr:
|
||||
switch left := expr.Left.(type) {
|
||||
case *sqlparser.ComparisonExpr:
|
||||
return reader.whereClauseNameErrs(left, alias)
|
||||
}
|
||||
switch right := expr.Right.(type) {
|
||||
case *sqlparser.ComparisonExpr:
|
||||
return reader.whereClauseNameErrs(right, alias)
|
||||
}
|
||||
}
|
||||
if conversionColumn != "" {
|
||||
return reader.colNameErrs([]string{conversionColumn})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// qualityCheck pads row with amountOfSep trailing copies of sep so that the
// row carries enough separators. A count of zero or less leaves row untouched.
func qualityCheck(row string, amountOfSep int, sep string) string {
	if amountOfSep <= 0 {
		return row
	}
	return row + strings.Repeat(sep, amountOfSep)
}
|
||||
|
||||
// writeRow appends myEntry to the partially built output row myRow and returns
// the new row, inserting delimiter between entries.
//
// An empty first entry is represented by a row holding only the delimiter (or
// by the empty string when just one column was requested, numOfReqCols == 1).
// When such a delimiter-only row is extended, the entry is appended directly so
// that the delimiter is not doubled. This check uses the configured delimiter
// rather than a hard-coded ',' so that it also holds for other delimiters.
func writeRow(myRow string, myEntry string, delimiter string, numOfReqCols int) string {
	switch {
	case myRow == "" && myEntry == "":
		// Empty first entry: nothing to emit for a single column, otherwise
		// emit the delimiter that will precede the next entry.
		if numOfReqCols == 1 {
			return myEntry
		}
		return myEntry + delimiter
	case myRow == delimiter:
		// Row so far is only the delimiter left behind by an empty first entry.
		return myRow + myEntry
	case myRow == "":
		return myEntry
	}
	return myRow + delimiter + myEntry
}
|
||||
|
||||
// colNameErrs is a function which makes sure that the headers are requested are
|
||||
// present in the file otherwise it throws an error.
|
||||
func (reader *Input) colNameErrs(columnNames []string) error {
|
||||
for i := 0; i < len(columnNames); i++ {
|
||||
if columnNames[i] == "" {
|
||||
continue
|
||||
}
|
||||
if !representsInt(columnNames[i]) && !reader.options.HeaderOpt {
|
||||
return ErrInvalidColumnIndex
|
||||
}
|
||||
if representsInt(columnNames[i]) {
|
||||
tempInt, _ := strconv.Atoi(columnNames[i])
|
||||
if tempInt > len(reader.Header()) || tempInt == 0 {
|
||||
return ErrInvalidColumnIndex
|
||||
}
|
||||
} else {
|
||||
if reader.options.HeaderOpt && !stringInSlice(columnNames[i], reader.Header()) {
|
||||
return ErrMissingHeaders
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// aggFuncToStr converts an array of floats into a properly formatted string.
|
||||
func (reader *Input) aggFuncToStr(myAggVals []float64) string {
|
||||
myRow := strconv.FormatFloat(myAggVals[0], 'f', 6, 64)
|
||||
for i := 1; i < len(myAggVals); i++ {
|
||||
aggregateval := strconv.FormatFloat(myAggVals[i], 'f', 6, 64)
|
||||
myRow = myRow + reader.options.OutputFieldDelimiter + aggregateval
|
||||
}
|
||||
return myRow
|
||||
}
|
||||
|
||||
// checkForDuplicates ensures we do not have an ambigious column name.
|
||||
func checkForDuplicates(columns []string, columnsMap map[string]int, hasDuplicates map[string]bool, lowercaseColumnsMap map[string]int) error {
|
||||
for i := 0; i < len(columns); i++ {
|
||||
columns[i] = strings.Replace(columns[i], " ", "_", len(columns[i]))
|
||||
if _, exist := columnsMap[columns[i]]; exist {
|
||||
return ErrAmbiguousFieldName
|
||||
}
|
||||
columnsMap[columns[i]] = i
|
||||
// This checks that if a key has already been put into the map, that we're
|
||||
// setting its appropriate value in has duplicates to be true.
|
||||
if _, exist := lowercaseColumnsMap[strings.ToLower(columns[i])]; exist {
|
||||
hasDuplicates[strings.ToLower(columns[i])] = true
|
||||
} else {
|
||||
lowercaseColumnsMap[strings.ToLower(columns[i])] = i
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// evaluateParserType is a function that takes a SQL value and returns it as an
|
||||
// interface converted into the appropriate value.
|
||||
func evaluateParserType(col *sqlparser.SQLVal) (interface{}, error) {
|
||||
colDataType := col.Type
|
||||
var val interface{}
|
||||
switch colDataType {
|
||||
case 0:
|
||||
val = string(col.Val)
|
||||
case 1:
|
||||
intVersion, isInt := strconv.Atoi(string(col.Val))
|
||||
if isInt != nil {
|
||||
return nil, ErrIntegerOverflow
|
||||
}
|
||||
val = intVersion
|
||||
case 2:
|
||||
floatVersion, isFloat := strconv.ParseFloat(string(col.Val), 64)
|
||||
if isFloat != nil {
|
||||
return nil, ErrIntegerOverflow
|
||||
}
|
||||
val = floatVersion
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// parseErrs is the function which handles all the errors that could occur
|
||||
// through use of function arguments such as column names in NULLIF
|
||||
func (reader *Input) parseErrs(columnNames []string, whereClause interface{}, alias string, myFuncs *SelectFuncs) error {
|
||||
// Below code cleans up column names.
|
||||
reader.processColumnNames(columnNames, alias)
|
||||
if columnNames[0] != "*" {
|
||||
if err := reader.colNameErrs(columnNames); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Below code ensures the whereClause has no errors.
|
||||
if whereClause != nil {
|
||||
tempClause := whereClause
|
||||
if err := reader.whereClauseNameErrs(tempClause, alias); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for i := 0; i < len(myFuncs.funcExpr); i++ {
|
||||
if myFuncs.funcExpr[i] == nil {
|
||||
continue
|
||||
}
|
||||
if err := reader.evaluateFuncErr(myFuncs.funcExpr[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,381 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2018 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package s3select
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/bzip2"
|
||||
"encoding/csv"
|
||||
"encoding/xml"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"net/http"
|
||||
|
||||
gzip "github.com/klauspost/pgzip"
|
||||
)
|
||||
|
||||
const (
	// progressTime is the interval between two progress messages.
	progressTime = 60 * time.Second

	// continuationTime is the idle interval after which a continuation
	// message is sent.
	continuationTime = 5 * time.Second
)
|
||||
|
||||
// progress is the XML payload of a progress message. It carries the byte
// counters of the running query.
type progress struct {
	BytesScanned   int64  `xml:"BytesScanned"`
	BytesProcessed int64  `xml:"BytesProcessed"`
	BytesReturned  int64  `xml:"BytesReturned"`
	Xmlns          string `xml:"xmlns,attr"` // emitted as the xmlns attribute
}
|
||||
|
||||
// stats is the XML payload of a stat message. It carries the final byte
// counters of the query.
type stats struct {
	BytesScanned   int64  `xml:"BytesScanned"`
	BytesProcessed int64  `xml:"BytesProcessed"`
	BytesReturned  int64  `xml:"BytesReturned"`
	Xmlns          string `xml:"xmlns,attr"` // emitted as the xmlns attribute
}
|
||||
|
||||
// statInfo holds the byte counters that are accumulated while a query runs and
// later copied into the progress and stat XML payloads.
type statInfo struct {
	BytesScanned   int64
	BytesReturned  int64
	BytesProcessed int64
}
|
||||
|
||||
// Input represents a record producing input from a formatted file or pipe.
type Input struct {
	// options are the settings this input was created with.
	options *Options
	// reader is the CSV reader the records are pulled from.
	reader *csv.Reader
	// firstRow, when non-nil, is a row that was read ahead; ReadRecord hands
	// it out before reading any further.
	firstRow []string
	// header holds the column names returned by Header.
	header []string
	// minOutputLength is the number of fields ReadRecord pads short records
	// up to.
	minOutputLength int
	// stats holds the byte counters reported in progress and stat messages.
	stats *statInfo
}
|
||||
|
||||
// Options holds the settings for an Input. Some of them are passed on to the
// underlying encoding/csv reader.
type Options struct {
	// HasHeader when true, will treat the first row as a header row.
	HasHeader bool

	// FieldDelimiter is the string that fields are delimited by. NewInput
	// uses its first character as the CSV field separator.
	FieldDelimiter string

	// Comments is the string marking comment lines. NewInput uses its first
	// character as the CSV comment character.
	Comments string

	// Name of the table that is used for querying
	Name string

	// ReadFrom is where the data will be read from.
	ReadFrom io.Reader

	// Compressed names the compression of the data in ReadFrom. NewInput
	// wraps the source in a gzip reader for "GZIP" and in a bzip2 reader for
	// "BZIP2"; "NONE" is treated as uncompressed when statistics are built.
	Compressed string

	// SQL expression meant to be evaluated.
	Expression string

	// What the outputted CSV will be delimited by.
	OutputFieldDelimiter string

	// Size of incoming object
	StreamSize int64

	// Whether Header is "USE" or another
	HeaderOpt bool
}
|
||||
|
||||
// NewInput sets up a new Input, the first row is read when this is run.
|
||||
// If there is a problem with reading the first row, the error is returned.
|
||||
// Otherwise, the returned reader can be reliably consumed with ReadRecord()
|
||||
// until ReadRecord() returns nil.
|
||||
func NewInput(opts *Options) (*Input, error) {
|
||||
myReader := opts.ReadFrom
|
||||
var tempBytesScanned int64
|
||||
tempBytesScanned = 0
|
||||
if opts.Compressed == "GZIP" {
|
||||
tempBytesScanned = opts.StreamSize
|
||||
var err error
|
||||
if myReader, err = gzip.NewReader(opts.ReadFrom); err != nil {
|
||||
return nil, ErrTruncatedInput
|
||||
}
|
||||
} else if opts.Compressed == "BZIP2" {
|
||||
tempBytesScanned = opts.StreamSize
|
||||
myReader = bzip2.NewReader(opts.ReadFrom)
|
||||
}
|
||||
|
||||
progress := &statInfo{
|
||||
BytesScanned: tempBytesScanned,
|
||||
BytesProcessed: 0,
|
||||
BytesReturned: 0,
|
||||
}
|
||||
reader := &Input{
|
||||
options: opts,
|
||||
reader: csv.NewReader(myReader),
|
||||
stats: progress,
|
||||
}
|
||||
reader.firstRow = nil
|
||||
|
||||
reader.reader.FieldsPerRecord = -1
|
||||
if reader.options.FieldDelimiter != "" {
|
||||
reader.reader.Comma = rune(reader.options.FieldDelimiter[0])
|
||||
}
|
||||
|
||||
if reader.options.Comments != "" {
|
||||
reader.reader.Comment = rune(reader.options.Comments[0])
|
||||
}
|
||||
|
||||
// QuoteCharacter - " (defaulted currently)
|
||||
reader.reader.LazyQuotes = true
|
||||
|
||||
if err := reader.readHeader(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// ReadRecord reads a single record from the . Always returns successfully.
|
||||
// If the record is empty, an empty []string is returned.
|
||||
// Record expand to match the current row size, adding blank fields as needed.
|
||||
// Records never return less then the number of fields in the first row.
|
||||
// Returns nil on EOF
|
||||
// In the event of a parse error due to an invalid record, it is logged, and
|
||||
// an empty []string is returned with the number of fields in the first row,
|
||||
// as if the record were empty.
|
||||
//
|
||||
// In general, this is a very tolerant of problems reader.
|
||||
func (reader *Input) ReadRecord() []string {
|
||||
var row []string
|
||||
var fileErr error
|
||||
|
||||
if reader.firstRow != nil {
|
||||
row = reader.firstRow
|
||||
reader.firstRow = nil
|
||||
return row
|
||||
}
|
||||
|
||||
row, fileErr = reader.reader.Read()
|
||||
emptysToAppend := reader.minOutputLength - len(row)
|
||||
if fileErr == io.EOF || fileErr == io.ErrClosedPipe {
|
||||
return nil
|
||||
} else if _, ok := fileErr.(*csv.ParseError); ok {
|
||||
emptysToAppend = reader.minOutputLength
|
||||
}
|
||||
|
||||
if emptysToAppend > 0 {
|
||||
for counter := 0; counter < emptysToAppend; counter++ {
|
||||
row = append(row, "")
|
||||
}
|
||||
}
|
||||
|
||||
return row
|
||||
}
|
||||
|
||||
// convertMySQL rewrites every double quote in the expression to a backtick,
// the quoting the MySQL parser expects for column names.
func convertMySQL(random string) string {
	const (
		doubleQuote = "\""
		backtick    = "`"
	)
	return strings.Replace(random, doubleQuote, backtick, -1)
}
|
||||
|
||||
// readHeader reads the header into the header variable if the header is present
|
||||
// as the first row of the csv
|
||||
func (reader *Input) readHeader() error {
|
||||
var readErr error
|
||||
if reader.options.HasHeader {
|
||||
reader.firstRow, readErr = reader.reader.Read()
|
||||
if readErr != nil {
|
||||
return ErrCSVParsingError
|
||||
}
|
||||
reader.header = reader.firstRow
|
||||
reader.firstRow = nil
|
||||
reader.minOutputLength = len(reader.header)
|
||||
} else {
|
||||
reader.firstRow, readErr = reader.reader.Read()
|
||||
reader.header = make([]string, len(reader.firstRow))
|
||||
for i := 0; i < reader.minOutputLength; i++ {
|
||||
reader.header[i] = strconv.Itoa(i)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createStatXML is the function which does the marshaling from the stat
|
||||
// structs into XML so that the progress and stat message can be sent
|
||||
func (reader *Input) createStatXML() (string, error) {
|
||||
if reader.options.Compressed == "NONE" {
|
||||
reader.stats.BytesProcessed = reader.options.StreamSize
|
||||
reader.stats.BytesScanned = reader.stats.BytesProcessed
|
||||
}
|
||||
statXML := stats{
|
||||
BytesScanned: reader.stats.BytesScanned,
|
||||
BytesProcessed: reader.stats.BytesProcessed,
|
||||
BytesReturned: reader.stats.BytesReturned,
|
||||
Xmlns: "",
|
||||
}
|
||||
out, err := xml.Marshal(statXML)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return xml.Header + string(out), nil
|
||||
}
|
||||
|
||||
// createProgressXML is the function which does the marshaling from the progress structs into XML so that the progress and stat message can be sent
|
||||
func (reader *Input) createProgressXML() (string, error) {
|
||||
if reader.options.HasHeader {
|
||||
reader.stats.BytesProcessed += processSize(reader.header)
|
||||
}
|
||||
if !(reader.options.Compressed != "NONE") {
|
||||
reader.stats.BytesScanned = reader.stats.BytesProcessed
|
||||
}
|
||||
progressXML := &progress{
|
||||
BytesScanned: reader.stats.BytesScanned,
|
||||
BytesProcessed: reader.stats.BytesProcessed,
|
||||
BytesReturned: reader.stats.BytesReturned,
|
||||
Xmlns: "",
|
||||
}
|
||||
out, err := xml.Marshal(progressXML)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return xml.Header + string(out), nil
|
||||
}
|
||||
|
||||
// Header returns the header of the reader, i.e. the column names that were
// set up when the Input was created. The returned slice is the reader's own
// slice, not a copy.
func (reader *Input) Header() []string {
	return reader.header
}
|
||||
|
||||
// Row is one item delivered by the select parser to Execute: either an output
// record or the error that ended the query.
type Row struct {
	// record is the formatted output row.
	record string
	// err, when non-nil, is reported to the client as an error message.
	err error
}
|
||||
|
||||
// Execute is the function where all the blocking occurs, It writes to the HTTP
|
||||
// response writer in a streaming fashion so that the client can actively use
|
||||
// the results before the query is finally finished executing. The
|
||||
func (reader *Input) Execute(writer io.Writer) error {
|
||||
myRow := make(chan *Row)
|
||||
curBuf := bytes.NewBuffer(make([]byte, 1000000))
|
||||
curBuf.Reset()
|
||||
progressTicker := time.NewTicker(progressTime)
|
||||
continuationTimer := time.NewTimer(continuationTime)
|
||||
defer progressTicker.Stop()
|
||||
defer continuationTimer.Stop()
|
||||
go reader.runSelectParser(convertMySQL(reader.options.Expression), myRow)
|
||||
for {
|
||||
select {
|
||||
case row, ok := <-myRow:
|
||||
if ok && row.err != nil {
|
||||
errorMessage := reader.writeErrorMessage(row.err, curBuf)
|
||||
_, err := errorMessage.WriteTo(writer)
|
||||
flusher, okFlush := writer.(http.Flusher)
|
||||
if okFlush {
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
curBuf.Reset()
|
||||
close(myRow)
|
||||
return nil
|
||||
} else if ok {
|
||||
message := reader.writeRecordMessage(row.record, curBuf)
|
||||
_, err := message.WriteTo(writer)
|
||||
flusher, okFlush := writer.(http.Flusher)
|
||||
if okFlush {
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
curBuf.Reset()
|
||||
reader.stats.BytesReturned += int64(len(row.record))
|
||||
if !continuationTimer.Stop() {
|
||||
<-continuationTimer.C
|
||||
}
|
||||
continuationTimer.Reset(continuationTime)
|
||||
} else if !ok {
|
||||
statPayload, err := reader.createStatXML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
statMessage := reader.writeStatMessage(statPayload, curBuf)
|
||||
_, err = statMessage.WriteTo(writer)
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
curBuf.Reset()
|
||||
message := reader.writeEndMessage(curBuf)
|
||||
_, err = message.WriteTo(writer)
|
||||
flusher, ok = writer.(http.Flusher)
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
case <-progressTicker.C:
|
||||
progressPayload, err := reader.createProgressXML()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
progressMessage := reader.writeProgressMessage(progressPayload, curBuf)
|
||||
_, err = progressMessage.WriteTo(writer)
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
curBuf.Reset()
|
||||
case <-continuationTimer.C:
|
||||
message := reader.writeContinuationMessage(curBuf)
|
||||
_, err := message.WriteTo(writer)
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
curBuf.Reset()
|
||||
continuationTimer.Reset(continuationTime)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,460 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2018 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
// DO NOT EDIT THIS PACKAGE DIRECTLY: This follows the protocol defined by
|
||||
// AmazonS3 found at
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
// Consult the Spec before making direct edits.
|
||||
|
||||
package s3select
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"hash/crc32"
|
||||
)
|
||||
|
||||
// The variables below hold the pre-encoded header sections of the event
// stream messages; they are filled in once by init. The tables in the comments
// list, per header: length of the header name, the header name, the header
// type (7), the length of the header value and the header value. Where a
// further trailing number is given it is the total encoded size of that header
// in bytes. The layouts are predefined by the AMZ protocol found here:
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html

// recordHeaders holds the headers of a "Records" event message:
//   11 :event-type    7  7 "Records"
//   13 :content-type  7 24 "application/octet-stream"
//   13 :message-type  7  5 "event"
var recordHeaders []byte

// endHeaders holds the headers of an "End" event message:
//   13 :message-type  7  5 "event"
//   11 :event-type    7  3 "End"
var endHeaders []byte

// contHeaders holds the headers of a "Cont" (continuation) event message:
//   13 :message-type  7  5 "event"
//   11 :event-type    7  4 "Cont"
var contHeaders []byte

// statHeaders holds the headers of a stat event message:
//   11 :event-type    7  5 "Stat"      20
//   13 :content-type  7  8 "text/xml"  25
//   13 :message-type  7  5 "event"     22
// NOTE(review): a value length of 5 (and a total of 20 bytes) suggests the
// event type is "Stats" rather than "Stat" — confirm against writeStatHeader.
var statHeaders []byte

// progressHeaders holds the headers of a "Progress" event message:
//   11 :event-type    7  8 "Progress"  23
//   13 :content-type  7  8 "text/xml"  25
//   13 :message-type  7  5 "event"     22
var progressHeaders []byte

// errHdrLen is the length of the non-variable portion of the headers of an
// "error" message, i.e. everything except the error code and error message
// values themselves:
//   11 :error-code     7  DEFINED "DEFINED"
//   14 :error-message  7  DEFINED "DEFINED"
//   13 :message-type   7  5       "error"
var errHdrLen int
|
||||
|
||||
// init encodes the fixed header sections of the event stream messages once at
// start-up, so that they can be reused for every message, and records the
// length of the fixed part of the error headers.
func init() {
	recordHeaders = writeRecordHeader()
	endHeaders = writeEndHeader()
	contHeaders = writeContHeader()
	statHeaders = writeStatHeader()
	progressHeaders = writeProgressHeader()
	// Size in bytes of the error headers without the error code and error
	// message values.
	errHdrLen = 55

}
|
||||
|
||||
// encodeHeaderStringValue encodes a header value: the length of s as a two
// byte big-endian number, followed by the bytes of s.
func encodeHeaderStringValue(s string) []byte {
	encoded := make([]byte, 2, 2+len(s))
	binary.BigEndian.PutUint16(encoded, uint16(len(s)))
	return append(encoded, s...)
}
|
||||
// encodeHeaderStringName encodes a header name: the length of s as a single
// byte, followed by the bytes of s.
func encodeHeaderStringName(s string) []byte {
	encoded := make([]byte, 0, 1+len(s))
	encoded = append(encoded, byte(len(s)))
	return append(encoded, s...)
}
|
||||
|
||||
// encodeNumber returns a slice of lenBytes bytes whose first byte is n and
// whose remaining bytes are zero.
func encodeNumber(n byte, lenBytes int) []byte {
	encoded := make([]byte, lenBytes)
	encoded[0] = n
	return encoded
}
|
||||
|
||||
// writePayloadSize returns the four byte, big-endian total message length:
// payload size plus header length plus 16.
func writePayloadSize(payloadSize int, headerLength int) []byte {
	var total [4]byte
	binary.BigEndian.PutUint32(total[:], uint32(payloadSize+headerLength+16))
	return total[:]
}
|
||||
|
||||
// writeHeaderSize returns headerLength as a four byte, big-endian number, the
// header size portion of the protocol.
func writeHeaderSize(headerLength int) []byte {
	var size [4]byte
	binary.BigEndian.PutUint32(size[:], uint32(headerLength))
	return size[:]
}
|
||||
|
||||
// writeCRC returns the CRC-32 (IEEE) checksum of myBuffer as four big-endian
// bytes. It is used for both the prelude and the end of a message.
func writeCRC(myBuffer []byte) []byte {
	var sum [4]byte
	binary.BigEndian.PutUint32(sum[:], crc32.ChecksumIEEE(myBuffer))
	return sum[:]
}
|
||||
|
||||
// writePayload returns the bytes of myPayload as a newly allocated slice,
// ready to be appended to a message after its headers.
//
// A single conversion is enough: it already allocates and copies, so no
// second buffer is needed.
func writePayload(myPayload string) []byte {
	return []byte(myPayload)
}
|
||||
|
||||
// writeRecordHeader is a function which writes the headers for the continuation
|
||||
// Message
|
||||
func writeRecordHeader() []byte {
|
||||
// This is predefined from AMZ protocol found here:
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
var currentMessage = &bytes.Buffer{}
|
||||
// 11 -event type - 7 - 7 "Records"
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":event-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("Records"))
|
||||
// Creation of the Header for Content-Type // 13 -content-type -7 -24
|
||||
// "application/octet-stream"
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":content-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("application/octet-stream"))
|
||||
// Creation of the Header for message-type 13 -message-type -7 5 "event"
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":message-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("event"))
|
||||
return currentMessage.Bytes()
|
||||
}
|
||||
|
||||
// writeEndHeader is a function which writes the headers for the continuation
|
||||
// Message
|
||||
func writeEndHeader() []byte {
|
||||
// This is predefined from AMZ protocol found here:
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
var currentMessage = &bytes.Buffer{}
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":event-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("End"))
|
||||
|
||||
// Creation of the Header for message-type 13 -message-type -7 5 "event"
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":message-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("event"))
|
||||
return currentMessage.Bytes()
|
||||
}
|
||||
|
||||
// writeContHeader is a function which writes the headers for the continuation
|
||||
// Message
|
||||
func writeContHeader() []byte {
|
||||
// This is predefined from AMZ protocol found here:
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
var currentMessage = &bytes.Buffer{}
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":event-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("Cont"))
|
||||
|
||||
// Creation of the Header for message-type 13 -message-type -7 5 "event"
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":message-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("event"))
|
||||
return currentMessage.Bytes()
|
||||
|
||||
}
|
||||
|
||||
// writeStatHeader is a function which writes the headers for the Stat
|
||||
// Message
|
||||
func writeStatHeader() []byte {
|
||||
// This is predefined from AMZ protocol found here:
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
var currentMessage = &bytes.Buffer{}
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":event-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("Stats"))
|
||||
// Creation of the Header for Content-Type // 13 -content-type -7 -8
|
||||
// "text/xml"
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":content-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("text/xml"))
|
||||
|
||||
// Creation of the Header for message-type 13 -message-type -7 5 "event"
|
||||
currentMessage.Write(encodeHeaderStringName(":message-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("event"))
|
||||
return currentMessage.Bytes()
|
||||
|
||||
}
|
||||
|
||||
// writeProgressHeader is a function which writes the headers for the Progress
|
||||
// Message
|
||||
func writeProgressHeader() []byte {
|
||||
// This is predefined from AMZ protocol found here:
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
var currentMessage = &bytes.Buffer{}
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":event-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("Progress"))
|
||||
// Creation of the Header for Content-Type // 13 -content-type -7 -8
|
||||
// "text/xml"
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":content-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("text/xml"))
|
||||
|
||||
// Creation of the Header for message-type 13 -message-type -7 5 "event"
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":message-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("event"))
|
||||
return currentMessage.Bytes()
|
||||
|
||||
}
|
||||
|
||||
// writeRecordMessage is the function which constructs the binary message for a
|
||||
// record message to be sent.
|
||||
func (csvOutput *Input) writeRecordMessage(payload string, currentMessage *bytes.Buffer) *bytes.Buffer {
|
||||
// The below are the specifications of the header for a "record" event
|
||||
// 11 -event type - 7 - 7 "Records"
|
||||
// 13 -content-type -7 -24 "application/octet-stream"
|
||||
// 13 -message-type -7 5 "event"
|
||||
// This is predefined from AMZ protocol found here:
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
headerLen := len(recordHeaders)
|
||||
// Writes the total size of the message.
|
||||
currentMessage.Write(writePayloadSize(len(payload), headerLen))
|
||||
// Writes the total size of the header.
|
||||
currentMessage.Write(writeHeaderSize(headerLen))
|
||||
// Writes the CRC of the Prelude
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
currentMessage.Write(recordHeaders)
|
||||
|
||||
// This part is where the payload is written, this will be only one row, since
|
||||
// we're sending one message at a types
|
||||
currentMessage.Write(writePayload(payload))
|
||||
|
||||
// Now we do a CRC check on the entire messages
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
return currentMessage
|
||||
|
||||
}
|
||||
|
||||
// writeContinuationMessage is the function which constructs the binary message
|
||||
// for a continuation message to be sent.
|
||||
func (csvOutput *Input) writeContinuationMessage(currentMessage *bytes.Buffer) *bytes.Buffer {
|
||||
// 11 -event type - 7 - 4 "Cont"
|
||||
// 13 -message-type -7 5 "event"
|
||||
// This is predefined from AMZ protocol found here:
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
headerLen := len(contHeaders)
|
||||
currentMessage.Write(writePayloadSize(0, headerLen))
|
||||
|
||||
currentMessage.Write(writeHeaderSize(headerLen))
|
||||
|
||||
// Calculate the Prelude CRC here:
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
|
||||
currentMessage.Write(contHeaders)
|
||||
|
||||
//Now we do a CRC check on the entire messages
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
return currentMessage
|
||||
|
||||
}
|
||||
|
||||
// writeEndMessage is the function which constructs the binary message
|
||||
// for a end message to be sent.
|
||||
func (csvOutput *Input) writeEndMessage(currentMessage *bytes.Buffer) *bytes.Buffer {
|
||||
// 11 -event type - 7 - 3 "End"
|
||||
// 13 -message-type -7 5 "event"
|
||||
// This is predefined from AMZ protocol found here:
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
headerLen := len(endHeaders)
|
||||
currentMessage.Write(writePayloadSize(0, headerLen))
|
||||
|
||||
currentMessage.Write(writeHeaderSize(headerLen))
|
||||
|
||||
//Calculate the Prelude CRC here:
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
|
||||
currentMessage.Write(endHeaders)
|
||||
|
||||
// Now we do a CRC check on the entire messages
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
return currentMessage
|
||||
|
||||
}
|
||||
|
||||
// writeStateMessage is the function which constructs the binary message for a
|
||||
// state message to be sent.
|
||||
func (csvOutput *Input) writeStatMessage(payload string, currentMessage *bytes.Buffer) *bytes.Buffer {
|
||||
// 11 -event type - 7 - 5 "Stat" 20
|
||||
// 13 -content-type -7 -8 "text/xml" 25
|
||||
// 13 -message-type -7 5 "event" 22
|
||||
// This is predefined from AMZ protocol found here:
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
headerLen := len(statHeaders)
|
||||
|
||||
currentMessage.Write(writePayloadSize(len(payload), headerLen))
|
||||
|
||||
currentMessage.Write(writeHeaderSize(headerLen))
|
||||
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
|
||||
currentMessage.Write(statHeaders)
|
||||
|
||||
// This part is where the payload is written, this will be only one row, since
|
||||
// we're sending one message at a types
|
||||
currentMessage.Write(writePayload(payload))
|
||||
|
||||
// Now we do a CRC check on the entire messages
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
return currentMessage
|
||||
|
||||
}
|
||||
|
||||
// writeProgressMessage is the function which constructs the binary message for
|
||||
// a progress message to be sent.
|
||||
func (csvOutput *Input) writeProgressMessage(payload string, currentMessage *bytes.Buffer) *bytes.Buffer {
|
||||
// The below are the specifications of the header for a "Progress" event
|
||||
// 11 -event type - 7 - 8 "Progress" 23
|
||||
// 13 -content-type -7 -8 "text/xml" 25
|
||||
// 13 -message-type -7 5 "event" 22
|
||||
// This is predefined from AMZ protocol found here:
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
headerLen := len(progressHeaders)
|
||||
|
||||
currentMessage.Write(writePayloadSize(len(payload), headerLen))
|
||||
|
||||
currentMessage.Write(writeHeaderSize(headerLen))
|
||||
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
|
||||
currentMessage.Write(progressHeaders)
|
||||
|
||||
// This part is where the payload is written, this will be only one row, since
|
||||
// we're sending one message at a types
|
||||
currentMessage.Write(writePayload(payload))
|
||||
|
||||
// Now we do a CRC check on the entire messages
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
return currentMessage
|
||||
|
||||
}
|
||||
|
||||
// writeErrorMessage is the function which constructs the binary message for a
|
||||
// error message to be sent.
|
||||
func (csvOutput *Input) writeErrorMessage(errorMessage error, currentMessage *bytes.Buffer) *bytes.Buffer {
|
||||
|
||||
// The below are the specifications of the header for a "error" event
|
||||
// 11 -error-code - 7 - DEFINED "DEFINED"
|
||||
// 14 -error-message -7 -DEFINED "DEFINED"
|
||||
// 13 -message-type -7 5 "error"
|
||||
// This is predefined from AMZ protocol found here:
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectSELECTContent.html
|
||||
sizeOfErrorCode := len(errorCodeResponse[errorMessage])
|
||||
sizeOfErrorMessage := len(errorMessage.Error())
|
||||
headerLen := errHdrLen + sizeOfErrorCode + sizeOfErrorMessage
|
||||
|
||||
currentMessage.Write(writePayloadSize(0, headerLen))
|
||||
|
||||
currentMessage.Write(writeHeaderSize(headerLen))
|
||||
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":error-code"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue(errorCodeResponse[errorMessage]))
|
||||
|
||||
// 14 -error-message -7 -DEFINED "DEFINED"
|
||||
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":error-message"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue(errorMessage.Error()))
|
||||
// Creation of the Header for message-type 13 -message-type -7 5 "error"
|
||||
// header name
|
||||
currentMessage.Write(encodeHeaderStringName(":message-type"))
|
||||
// header type
|
||||
currentMessage.Write(encodeNumber(7, 1))
|
||||
// header value and header value length
|
||||
currentMessage.Write(encodeHeaderStringValue("error"))
|
||||
|
||||
// Now we do a CRC check on the entire messages
|
||||
currentMessage.Write(writeCRC(currentMessage.Bytes()))
|
||||
return currentMessage
|
||||
|
||||
}
|
|
@ -0,0 +1,415 @@
|
|||
/*
|
||||
* Minio Cloud Storage, (C) 2018 Minio, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package s3select
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/xwb1989/sqlparser"
|
||||
)
|
||||
|
||||
// SelectFuncs contains the relevant values from the parser for S3 Select
// Functions
type SelectFuncs struct {
	// funcExpr holds the parsed expression of each supported non-aggregate
	// function, stored at the position of its entry in the SELECT list.
	// ParseSelect sizes it to the number of select expressions; positions
	// that are not such functions stay nil.
	funcExpr []*sqlparser.FuncExpr
	// index is parallel to funcExpr: ParseSelect stores i at position i for
	// every select expression that is a supported function.
	index []int
}
|
||||
|
||||
// RunSqlParser allows us to easily bundle all the functions from above and run
|
||||
// them in the appropriate order.
|
||||
func (reader *Input) runSelectParser(selectExpression string, myRow chan *Row) {
|
||||
reqCols, alias, myLimit, whereClause, aggFunctionNames, myFuncs, myErr := reader.ParseSelect(selectExpression)
|
||||
if myErr != nil {
|
||||
rowStruct := &Row{
|
||||
err: myErr,
|
||||
}
|
||||
myRow <- rowStruct
|
||||
return
|
||||
}
|
||||
reader.processSelectReq(reqCols, alias, whereClause, myLimit, aggFunctionNames, myRow, myFuncs)
|
||||
}
|
||||
|
||||
// ParseSelect parses the SELECT expression, and effectively tokenizes it into
|
||||
// its separate parts. It returns the requested column names,alias,limit of
|
||||
// records, and the where clause.
|
||||
func (reader *Input) ParseSelect(sqlInput string) ([]string, string, int, interface{}, []string, *SelectFuncs, error) {
|
||||
// return columnNames, alias, limitOfRecords, whereclause,coalStore, nil
|
||||
|
||||
stmt, err := sqlparser.Parse(sqlInput)
|
||||
var whereClause interface{}
|
||||
var alias string
|
||||
var limit int
|
||||
myFuncs := &SelectFuncs{}
|
||||
// TODO Maybe can parse their errors a bit to return some more of the s3 errors
|
||||
if err != nil {
|
||||
return nil, "", 0, nil, nil, nil, ErrLexerInvalidChar
|
||||
}
|
||||
switch stmt := stmt.(type) {
|
||||
case *sqlparser.Select:
|
||||
// evaluates the where clause
|
||||
functionNames := make([]string, len(stmt.SelectExprs))
|
||||
columnNames := make([]string, len(stmt.SelectExprs))
|
||||
|
||||
if stmt.Where != nil {
|
||||
switch expr := stmt.Where.Expr.(type) {
|
||||
default:
|
||||
whereClause = expr
|
||||
case *sqlparser.ComparisonExpr:
|
||||
whereClause = expr
|
||||
}
|
||||
}
|
||||
if stmt.SelectExprs != nil {
|
||||
for i := 0; i < len(stmt.SelectExprs); i++ {
|
||||
switch expr := stmt.SelectExprs[i].(type) {
|
||||
case *sqlparser.StarExpr:
|
||||
columnNames[0] = "*"
|
||||
case *sqlparser.AliasedExpr:
|
||||
switch smallerexpr := expr.Expr.(type) {
|
||||
case *sqlparser.FuncExpr:
|
||||
if smallerexpr.IsAggregate() {
|
||||
functionNames[i] = smallerexpr.Name.CompliantName()
|
||||
// Will return function name
|
||||
// Case to deal with if we have functions and not an asterix
|
||||
switch tempagg := smallerexpr.Exprs[0].(type) {
|
||||
case *sqlparser.StarExpr:
|
||||
columnNames[0] = "*"
|
||||
if smallerexpr.Name.CompliantName() != "count" {
|
||||
return nil, "", 0, nil, nil, nil, ErrParseUnsupportedCallWithStar
|
||||
}
|
||||
case *sqlparser.AliasedExpr:
|
||||
switch col := tempagg.Expr.(type) {
|
||||
case *sqlparser.BinaryExpr:
|
||||
return nil, "", 0, nil, nil, nil, ErrParseNonUnaryAgregateFunctionCall
|
||||
case *sqlparser.ColName:
|
||||
columnNames[i] = col.Name.CompliantName()
|
||||
}
|
||||
}
|
||||
// Case to deal with if COALESCE was used..
|
||||
} else if supportedFunc(smallerexpr.Name.CompliantName()) {
|
||||
if myFuncs.funcExpr == nil {
|
||||
myFuncs.funcExpr = make([]*sqlparser.FuncExpr, len(stmt.SelectExprs))
|
||||
myFuncs.index = make([]int, len(stmt.SelectExprs))
|
||||
}
|
||||
myFuncs.funcExpr[i] = smallerexpr
|
||||
myFuncs.index[i] = i
|
||||
} else {
|
||||
return nil, "", 0, nil, nil, nil, ErrUnsupportedSQLOperation
|
||||
}
|
||||
case *sqlparser.ColName:
|
||||
columnNames[i] = smallerexpr.Name.CompliantName()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This code retrieves the alias and makes sure it is set to the correct
|
||||
// value, if not it sets it to the tablename
|
||||
if (stmt.From) != nil {
|
||||
for i := 0; i < len(stmt.From); i++ {
|
||||
switch smallerexpr := stmt.From[i].(type) {
|
||||
case *sqlparser.JoinTableExpr:
|
||||
return nil, "", 0, nil, nil, nil, ErrParseMalformedJoin
|
||||
case *sqlparser.AliasedTableExpr:
|
||||
alias = smallerexpr.As.CompliantName()
|
||||
if alias == "" {
|
||||
alias = sqlparser.GetTableName(smallerexpr.Expr).CompliantName()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if stmt.Limit != nil {
|
||||
switch expr := stmt.Limit.Rowcount.(type) {
|
||||
case *sqlparser.SQLVal:
|
||||
// The Value of how many rows we're going to limit by
|
||||
limit, _ = strconv.Atoi(string(expr.Val[:]))
|
||||
}
|
||||
}
|
||||
if stmt.GroupBy != nil {
|
||||
return nil, "", 0, nil, nil, nil, ErrParseUnsupportedLiteralsGroupBy
|
||||
}
|
||||
if stmt.OrderBy != nil {
|
||||
return nil, "", 0, nil, nil, nil, ErrParseUnsupportedToken
|
||||
}
|
||||
if err := reader.parseErrs(columnNames, whereClause, alias, myFuncs); err != nil {
|
||||
return nil, "", 0, nil, nil, nil, err
|
||||
}
|
||||
return columnNames, alias, limit, whereClause, functionNames, myFuncs, nil
|
||||
}
|
||||
return nil, "", 0, nil, nil, nil, nil
|
||||
}
|
||||
|
||||
// This is the main function, It goes row by row and for records which validate
|
||||
// the where clause it currently prints the appropriate row given the requested
|
||||
// columns.
|
||||
func (reader *Input) processSelectReq(reqColNames []string, alias string, whereClause interface{}, limitOfRecords int, functionNames []string, myRow chan *Row, myFunc *SelectFuncs) {
|
||||
counter := -1
|
||||
filtrCount := 0
|
||||
functionFlag := false
|
||||
// My values is used to store our aggregation values if we need to store them.
|
||||
myAggVals := make([]float64, len(reqColNames))
|
||||
var columns []string
|
||||
// LowercasecolumnsMap is used in accordance with hasDuplicates so that we can
|
||||
// raise the error "Ambigious" if a case insensitive column is provided and we
|
||||
// have multiple matches.
|
||||
lowercaseColumnsMap := make(map[string]int)
|
||||
hasDuplicates := make(map[string]bool)
|
||||
// ColumnsMap stores our columns and their index.
|
||||
columnsMap := make(map[string]int)
|
||||
if limitOfRecords == 0 {
|
||||
limitOfRecords = math.MaxInt64
|
||||
}
|
||||
|
||||
for {
|
||||
record := reader.ReadRecord()
|
||||
reader.stats.BytesProcessed += processSize(record)
|
||||
if record == nil {
|
||||
if functionFlag {
|
||||
rowStruct := &Row{
|
||||
record: reader.aggFuncToStr(myAggVals) + "\n",
|
||||
}
|
||||
myRow <- rowStruct
|
||||
}
|
||||
close(myRow)
|
||||
return
|
||||
}
|
||||
if counter == -1 && reader.options.HeaderOpt && len(reader.header) > 0 {
|
||||
columns = reader.Header()
|
||||
myErr := checkForDuplicates(columns, columnsMap, hasDuplicates, lowercaseColumnsMap)
|
||||
if myErr != nil {
|
||||
rowStruct := &Row{
|
||||
err: myErr,
|
||||
}
|
||||
myRow <- rowStruct
|
||||
return
|
||||
}
|
||||
} else if counter == -1 && len(reader.header) > 0 {
|
||||
columns = reader.Header()
|
||||
}
|
||||
// When we have reached our limit, on what the user specified as the number
|
||||
// of rows they wanted, we terminate our interpreter.
|
||||
if filtrCount == limitOfRecords && limitOfRecords != 0 {
|
||||
close(myRow)
|
||||
return
|
||||
}
|
||||
// The call to the where function clause,ensures that the rows we print match our where clause.
|
||||
condition, myErr := matchesMyWhereClause(record, columnsMap, alias, whereClause)
|
||||
if myErr != nil {
|
||||
rowStruct := &Row{
|
||||
err: myErr,
|
||||
}
|
||||
myRow <- rowStruct
|
||||
return
|
||||
}
|
||||
if condition {
|
||||
// if its an asterix we just print everything in the row
|
||||
if reqColNames[0] == "*" && functionNames[0] == "" {
|
||||
rowStruct := &Row{
|
||||
record: reader.printAsterix(record) + "\n",
|
||||
}
|
||||
myRow <- rowStruct
|
||||
} else if alias != "" {
|
||||
// This is for dealing with the case of if we have to deal with a
|
||||
// request for a column with an index e.g A_1.
|
||||
if representsInt(reqColNames[0]) {
|
||||
// This checks whether any aggregation function was called as now we
|
||||
// no longer will go through printing each row, and only print at the
|
||||
// end
|
||||
if len(functionNames) > 0 && functionNames[0] != "" {
|
||||
functionFlag = true
|
||||
aggregationFunctions(counter, filtrCount, myAggVals, columnsMap, reqColNames, functionNames, record)
|
||||
} else {
|
||||
// The code below finds the appropriate columns of the row given the
|
||||
// indicies provided in the SQL request and utilizes the map to
|
||||
// retrieve the correct part of the row.
|
||||
myQueryRow, myErr := reader.processColNameIndex(record, reqColNames, columns)
|
||||
if myErr != nil {
|
||||
rowStruct := &Row{
|
||||
err: myErr,
|
||||
}
|
||||
myRow <- rowStruct
|
||||
return
|
||||
}
|
||||
rowStruct := &Row{
|
||||
record: myQueryRow + "\n",
|
||||
}
|
||||
myRow <- rowStruct
|
||||
}
|
||||
} else {
|
||||
// This code does aggregation if we were provided column names in the
|
||||
// form of acutal names rather an indices.
|
||||
if len(functionNames) > 0 && functionNames[0] != "" {
|
||||
functionFlag = true
|
||||
aggregationFunctions(counter, filtrCount, myAggVals, columnsMap, reqColNames, functionNames, record)
|
||||
} else {
|
||||
// This code prints the appropriate part of the row given the filter
|
||||
// and select request, if the select request was based on column
|
||||
// names rather than indices.
|
||||
myQueryRow, myErr := reader.processColNameLiteral(record, reqColNames, columns, columnsMap, myFunc)
|
||||
if myErr != nil {
|
||||
rowStruct := &Row{
|
||||
err: myErr,
|
||||
}
|
||||
myRow <- rowStruct
|
||||
return
|
||||
}
|
||||
rowStruct := &Row{
|
||||
record: myQueryRow + "\n",
|
||||
}
|
||||
myRow <- rowStruct
|
||||
}
|
||||
}
|
||||
}
|
||||
filtrCount++
|
||||
}
|
||||
counter++
|
||||
}
|
||||
}
|
||||
|
||||
// printAsterix helps to print out the entire row if an asterix is used.
|
||||
func (reader *Input) printAsterix(record []string) string {
|
||||
myRow := record[0]
|
||||
for i := 1; i < len(record); i++ {
|
||||
myRow = myRow + reader.options.OutputFieldDelimiter + record[i]
|
||||
}
|
||||
return myRow
|
||||
}
|
||||
|
||||
// processColumnNames is a function which allows for cleaning of column names.
|
||||
func (reader *Input) processColumnNames(reqColNames []string, alias string) error {
|
||||
for i := 0; i < len(reqColNames); i++ {
|
||||
// The code below basically cleans the column name of its alias and other
|
||||
// syntax, so that we can extract its pure name.
|
||||
reqColNames[i] = cleanCol(reqColNames[i], alias)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// processColNameIndex is the function which creates the row for an index based
|
||||
// query.
|
||||
func (reader *Input) processColNameIndex(record []string, reqColNames []string, columns []string) (string, error) {
|
||||
myRow := ""
|
||||
for i := 0; i < len(reqColNames); i++ {
|
||||
// COALESCE AND NULLIF do not support index based access.
|
||||
if reqColNames[0] == "0" {
|
||||
return "", ErrInvalidColumnIndex
|
||||
}
|
||||
// Subtract 1 because AWS Indexing is not 0 based, it starts at 1.
|
||||
mytempindex, err := strconv.Atoi(reqColNames[i])
|
||||
mytempindex = mytempindex - 1
|
||||
if mytempindex > len(columns) {
|
||||
return "", ErrInvalidColumnIndex
|
||||
}
|
||||
myRow = writeRow(myRow, record[mytempindex], reader.options.OutputFieldDelimiter, len(reqColNames))
|
||||
if err != nil {
|
||||
return "", ErrMissingHeaders
|
||||
}
|
||||
}
|
||||
if len(myRow) > 1000000 {
|
||||
return "", ErrOverMaxRecordSize
|
||||
}
|
||||
if strings.Count(myRow, reader.options.OutputFieldDelimiter) != len(reqColNames)-1 {
|
||||
myRow = qualityCheck(myRow, len(reqColNames)-1-strings.Count(myRow, reader.options.OutputFieldDelimiter), reader.options.OutputFieldDelimiter)
|
||||
}
|
||||
return myRow, nil
|
||||
}
|
||||
|
||||
// processColNameLiteral is the function which creates the row for an name based
|
||||
// query.
|
||||
func (reader *Input) processColNameLiteral(record []string, reqColNames []string, columns []string, columnsMap map[string]int, myFunc *SelectFuncs) (string, error) {
|
||||
myRow := ""
|
||||
for i := 0; i < len(reqColNames); i++ {
|
||||
// this is the case to deal with COALESCE.
|
||||
if reqColNames[i] == "" && isValidFunc(myFunc.index, i) {
|
||||
myVal := evaluateFuncExpr(myFunc.funcExpr[i], "", record, columnsMap)
|
||||
myRow = writeRow(myRow, myVal, reader.options.OutputFieldDelimiter, len(reqColNames))
|
||||
continue
|
||||
}
|
||||
myTempIndex, notFound := columnsMap[trimQuotes(reqColNames[i])]
|
||||
if !notFound {
|
||||
return "", ErrMissingHeaders
|
||||
}
|
||||
myRow = writeRow(myRow, record[myTempIndex], reader.options.OutputFieldDelimiter, len(reqColNames))
|
||||
}
|
||||
if len(myRow) > 1000000 {
|
||||
return "", ErrOverMaxRecordSize
|
||||
}
|
||||
if strings.Count(myRow, reader.options.OutputFieldDelimiter) != len(reqColNames)-1 {
|
||||
myRow = qualityCheck(myRow, len(reqColNames)-1-strings.Count(myRow, reader.options.OutputFieldDelimiter), reader.options.OutputFieldDelimiter)
|
||||
}
|
||||
return myRow, nil
|
||||
|
||||
}
|
||||
|
||||
// aggregationFunctions is a function which performs the actual aggregation
|
||||
// methods on the given row, it uses an array defined the the main parsing
|
||||
// function to keep track of values.
|
||||
func aggregationFunctions(counter int, filtrCount int, myAggVals []float64, columnsMap map[string]int, storeReqCols []string, storeFunctions []string, record []string) error {
|
||||
for i := 0; i < len(storeFunctions); i++ {
|
||||
if storeFunctions[i] == "" {
|
||||
i++
|
||||
} else if storeFunctions[i] == "count" {
|
||||
myAggVals[i]++
|
||||
} else {
|
||||
// If column names are provided as an index it'll use this if statement instead of the else/
|
||||
var convAggFloat float64
|
||||
if representsInt(storeReqCols[i]) {
|
||||
myIndex, _ := strconv.Atoi(storeReqCols[i])
|
||||
convAggFloat, _ = strconv.ParseFloat(record[myIndex], 64)
|
||||
|
||||
} else {
|
||||
// case that the columns are in the form of named columns rather than indices.
|
||||
convAggFloat, _ = strconv.ParseFloat(record[columnsMap[trimQuotes(storeReqCols[i])]], 64)
|
||||
|
||||
}
|
||||
// This if statement is for calculating the min.
|
||||
if storeFunctions[i] == "min" {
|
||||
if counter == -1 {
|
||||
myAggVals[i] = math.MaxFloat64
|
||||
}
|
||||
if convAggFloat < myAggVals[i] {
|
||||
myAggVals[i] = convAggFloat
|
||||
}
|
||||
|
||||
} else if storeFunctions[i] == "max" {
|
||||
// This if statement is for calculating the max.
|
||||
if counter == -1 {
|
||||
myAggVals[i] = math.SmallestNonzeroFloat64
|
||||
}
|
||||
if convAggFloat > myAggVals[i] {
|
||||
myAggVals[i] = convAggFloat
|
||||
}
|
||||
|
||||
} else if storeFunctions[i] == "sum" {
|
||||
// This if statement is for calculating the sum.
|
||||
myAggVals[i] += convAggFloat
|
||||
|
||||
} else if storeFunctions[i] == "avg" {
|
||||
// This if statement is for calculating the average.
|
||||
if filtrCount == 0 {
|
||||
myAggVals[i] = convAggFloat
|
||||
} else {
|
||||
myAggVals[i] = (convAggFloat + (myAggVals[i] * float64(filtrCount))) / float64((filtrCount + 1))
|
||||
}
|
||||
} else {
|
||||
return ErrParseNonUnaryAgregateFunctionCall
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2012 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package flate
|
||||
|
||||
// forwardCopy copies n bytes from mem[src:] to mem[dst:], always proceeding
// from the lowest index upward, even when the two ranges overlap. With an
// overlapping destination this re-reads bytes that were just written, so the
// result is the same as:
//
//	for i := 0; i < n; i++ {
//		mem[dst+i] = mem[src+i]
//	}
func forwardCopy(mem []byte, dst, src, n int) {
	if dst > src {
		// The destination begins after the source, so the output is the
		// pattern mem[src:dst] repeated. Each pass copies one whole,
		// non-overlapping instance of the pattern and moves dst past it,
		// which doubles the pattern length available to the next pass.
		for gap := dst - src; gap < n; gap = dst - src {
			copy(mem[dst:dst+gap], mem[src:src+gap])
			n -= gap
			dst += gap
		}
	}
	// Here either dst <= src, or the remaining n bytes fit inside the gap
	// between src and dst; in both cases the built-in copy yields the
	// forward-copy result.
	copy(mem[dst:dst+n], mem[src:src+n])
}
|
|
@ -0,0 +1,41 @@
|
|||
//+build !noasm
|
||||
//+build !appengine
|
||||
|
||||
// Copyright 2015, Klaus Post, see LICENSE for details.
|
||||
|
||||
package flate
|
||||
|
||||
import (
|
||||
"github.com/klauspost/cpuid"
|
||||
)
|
||||
|
||||
// crc32sse returns a hash for the first 4 bytes of the slice
|
||||
// len(a) must be >= 4.
|
||||
//go:noescape
|
||||
func crc32sse(a []byte) uint32
|
||||
|
||||
// crc32sseAll calculates hashes for each 4-byte set in a.
|
||||
// dst must be at least len(a) - 4 in size.
|
||||
// The size is not checked by the assembly.
|
||||
//go:noescape
|
||||
func crc32sseAll(a []byte, dst []uint32)
|
||||
|
||||
// matchLenSSE4 returns the number of matching bytes in a and b
|
||||
// up to length 'max'. Both slices must be at least 'max'
|
||||
// bytes in size.
|
||||
//
|
||||
// TODO: drop the "SSE4" name, since it doesn't use any SSE instructions.
|
||||
//
|
||||
//go:noescape
|
||||
func matchLenSSE4(a, b []byte, max int) int
|
||||
|
||||
// histogram accumulates a histogram of b in h.
|
||||
// h must be at least 256 entries in length,
|
||||
// and must be cleared before calling this function.
|
||||
//go:noescape
|
||||
func histogram(b []byte, h []int32)
|
||||
|
||||
// Detect SSE 4.2 feature.
|
||||
func init() {
|
||||
useSSE42 = cpuid.CPU.SSE42()
|
||||
}
|
|
@ -0,0 +1,213 @@
|
|||
//+build !noasm
|
||||
//+build !appengine
|
||||
|
||||
// Copyright 2015, Klaus Post, see LICENSE for details.
|
||||
|
||||
// func crc32sse(a []byte) uint32
|
||||
TEXT ·crc32sse(SB), 4, $0
|
||||
MOVQ a+0(FP), R10
|
||||
XORQ BX, BX
|
||||
|
||||
// CRC32 dword (R10), EBX
|
||||
BYTE $0xF2; BYTE $0x41; BYTE $0x0f
|
||||
BYTE $0x38; BYTE $0xf1; BYTE $0x1a
|
||||
|
||||
MOVL BX, ret+24(FP)
|
||||
RET
|
||||
|
||||
// func crc32sseAll(a []byte, dst []uint32)
|
||||
TEXT ·crc32sseAll(SB), 4, $0
|
||||
MOVQ a+0(FP), R8 // R8: src
|
||||
MOVQ a_len+8(FP), R10 // input length
|
||||
MOVQ dst+24(FP), R9 // R9: dst
|
||||
SUBQ $4, R10
|
||||
JS end
|
||||
JZ one_crc
|
||||
MOVQ R10, R13
|
||||
SHRQ $2, R10 // len/4
|
||||
ANDQ $3, R13 // len&3
|
||||
XORQ BX, BX
|
||||
ADDQ $1, R13
|
||||
TESTQ R10, R10
|
||||
JZ rem_loop
|
||||
|
||||
crc_loop:
|
||||
MOVQ (R8), R11
|
||||
XORQ BX, BX
|
||||
XORQ DX, DX
|
||||
XORQ DI, DI
|
||||
MOVQ R11, R12
|
||||
SHRQ $8, R11
|
||||
MOVQ R12, AX
|
||||
MOVQ R11, CX
|
||||
SHRQ $16, R12
|
||||
SHRQ $16, R11
|
||||
MOVQ R12, SI
|
||||
|
||||
// CRC32 EAX, EBX
|
||||
BYTE $0xF2; BYTE $0x0f
|
||||
BYTE $0x38; BYTE $0xf1; BYTE $0xd8
|
||||
|
||||
// CRC32 ECX, EDX
|
||||
BYTE $0xF2; BYTE $0x0f
|
||||
BYTE $0x38; BYTE $0xf1; BYTE $0xd1
|
||||
|
||||
// CRC32 ESI, EDI
|
||||
BYTE $0xF2; BYTE $0x0f
|
||||
BYTE $0x38; BYTE $0xf1; BYTE $0xfe
|
||||
MOVL BX, (R9)
|
||||
MOVL DX, 4(R9)
|
||||
MOVL DI, 8(R9)
|
||||
|
||||
XORQ BX, BX
|
||||
MOVL R11, AX
|
||||
|
||||
// CRC32 EAX, EBX
|
||||
BYTE $0xF2; BYTE $0x0f
|
||||
BYTE $0x38; BYTE $0xf1; BYTE $0xd8
|
||||
MOVL BX, 12(R9)
|
||||
|
||||
ADDQ $16, R9
|
||||
ADDQ $4, R8
|
||||
XORQ BX, BX
|
||||
SUBQ $1, R10
|
||||
JNZ crc_loop
|
||||
|
||||
rem_loop:
|
||||
MOVL (R8), AX
|
||||
|
||||
// CRC32 EAX, EBX
|
||||
BYTE $0xF2; BYTE $0x0f
|
||||
BYTE $0x38; BYTE $0xf1; BYTE $0xd8
|
||||
|
||||
MOVL BX, (R9)
|
||||
ADDQ $4, R9
|
||||
ADDQ $1, R8
|
||||
XORQ BX, BX
|
||||
SUBQ $1, R13
|
||||
JNZ rem_loop
|
||||
|
||||
end:
|
||||
RET
|
||||
|
||||
one_crc:
|
||||
MOVQ $1, R13
|
||||
XORQ BX, BX
|
||||
JMP rem_loop
|
||||
|
||||
// func matchLenSSE4(a, b []byte, max int) int
|
||||
TEXT ·matchLenSSE4(SB), 4, $0
|
||||
MOVQ a_base+0(FP), SI
|
||||
MOVQ b_base+24(FP), DI
|
||||
MOVQ DI, DX
|
||||
MOVQ max+48(FP), CX
|
||||
|
||||
cmp8:
|
||||
// As long as we are 8 or more bytes before the end of max, we can load and
|
||||
// compare 8 bytes at a time. If those 8 bytes are equal, repeat.
|
||||
CMPQ CX, $8
|
||||
JLT cmp1
|
||||
MOVQ (SI), AX
|
||||
MOVQ (DI), BX
|
||||
CMPQ AX, BX
|
||||
JNE bsf
|
||||
ADDQ $8, SI
|
||||
ADDQ $8, DI
|
||||
SUBQ $8, CX
|
||||
JMP cmp8
|
||||
|
||||
bsf:
|
||||
// If those 8 bytes were not equal, XOR the two 8 byte values, and return
|
||||
// the index of the first byte that differs. The BSF instruction finds the
|
||||
// least significant 1 bit, the amd64 architecture is little-endian, and
|
||||
// the shift by 3 converts a bit index to a byte index.
|
||||
XORQ AX, BX
|
||||
BSFQ BX, BX
|
||||
SHRQ $3, BX
|
||||
ADDQ BX, DI
|
||||
|
||||
// Subtract off &b[0] to convert from &b[ret] to ret, and return.
|
||||
SUBQ DX, DI
|
||||
MOVQ DI, ret+56(FP)
|
||||
RET
|
||||
|
||||
cmp1:
|
||||
// In the slices' tail, compare 1 byte at a time.
|
||||
CMPQ CX, $0
|
||||
JEQ matchLenEnd
|
||||
MOVB (SI), AX
|
||||
MOVB (DI), BX
|
||||
CMPB AX, BX
|
||||
JNE matchLenEnd
|
||||
ADDQ $1, SI
|
||||
ADDQ $1, DI
|
||||
SUBQ $1, CX
|
||||
JMP cmp1
|
||||
|
||||
matchLenEnd:
|
||||
// Subtract off &b[0] to convert from &b[ret] to ret, and return.
|
||||
SUBQ DX, DI
|
||||
MOVQ DI, ret+56(FP)
|
||||
RET
|
||||
|
||||
// func histogram(b []byte, h []int32)
|
||||
TEXT ·histogram(SB), 4, $0
|
||||
MOVQ b+0(FP), SI // SI: &b
|
||||
MOVQ b_len+8(FP), R9 // R9: len(b)
|
||||
MOVQ h+24(FP), DI // DI: Histogram
|
||||
MOVQ R9, R8
|
||||
SHRQ $3, R8
|
||||
JZ hist1
|
||||
XORQ R11, R11
|
||||
|
||||
loop_hist8:
|
||||
MOVQ (SI), R10
|
||||
|
||||
MOVB R10, R11
|
||||
INCL (DI)(R11*4)
|
||||
SHRQ $8, R10
|
||||
|
||||
MOVB R10, R11
|
||||
INCL (DI)(R11*4)
|
||||
SHRQ $8, R10
|
||||
|
||||
MOVB R10, R11
|
||||
INCL (DI)(R11*4)
|
||||
SHRQ $8, R10
|
||||
|
||||
MOVB R10, R11
|
||||
INCL (DI)(R11*4)
|
||||
SHRQ $8, R10
|
||||
|
||||
MOVB R10, R11
|
||||
INCL (DI)(R11*4)
|
||||
SHRQ $8, R10
|
||||
|
||||
MOVB R10, R11
|
||||
INCL (DI)(R11*4)
|
||||
SHRQ $8, R10
|
||||
|
||||
MOVB R10, R11
|
||||
INCL (DI)(R11*4)
|
||||
SHRQ $8, R10
|
||||
|
||||
INCL (DI)(R10*4)
|
||||
|
||||
ADDQ $8, SI
|
||||
DECQ R8
|
||||
JNZ loop_hist8
|
||||
|
||||
hist1:
|
||||
ANDQ $7, R9
|
||||
JZ end_hist
|
||||
XORQ R10, R10
|
||||
|
||||
loop_hist1:
|
||||
MOVB (SI), R10
|
||||
INCL (DI)(R10*4)
|
||||
INCQ SI
|
||||
DECQ R9
|
||||
JNZ loop_hist1
|
||||
|
||||
end_hist:
|
||||
RET
|
|
@ -0,0 +1,35 @@
|
|||
//+build !amd64 noasm appengine
|
||||
|
||||
// Copyright 2015, Klaus Post, see LICENSE for details.
|
||||
|
||||
package flate
|
||||
|
||||
func init() {
|
||||
useSSE42 = false
|
||||
}
|
||||
|
||||
// crc32sse should never be called.
|
||||
func crc32sse(a []byte) uint32 {
|
||||
panic("no assembler")
|
||||
}
|
||||
|
||||
// crc32sseAll should never be called.
|
||||
func crc32sseAll(a []byte, dst []uint32) {
|
||||
panic("no assembler")
|
||||
}
|
||||
|
||||
// matchLenSSE4 is the placeholder used when the assembly implementation is
// excluded by this file's build constraints. It should never be called: it
// ignores its arguments and always panics with "no assembler".
func matchLenSSE4(a, b []byte, max int) int {
	// panic is a terminating statement, so no trailing return is needed.
	panic("no assembler")
}
|
||||
|
||||
// histogram adds the number of occurrences of each byte value in b to the
// first 256 entries of h.
//
// h must hold at least 256 entries (the re-slice below panics otherwise), and
// the caller is expected to have zeroed them beforehand: counts are added to
// whatever h already contains.
func histogram(b []byte, h []int32) {
	// Narrow to exactly 256 counters so that indexing by a byte value can
	// never go out of range.
	counts := h[:256]
	for i := 0; i < len(b); i++ {
		counts[b[i]]++
	}
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,184 @@
|
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package flate
|
||||
|
||||
// dictDecoder implements the LZ77 sliding dictionary as used in decompression.
// LZ77 decompresses data through sequences of two forms of commands:
//
//   - Literal insertions: runs of one or more symbols are inserted into the
//     data stream as is. This is done with writeByte for a single symbol, or
//     with writeSlice followed by writeMark for several symbols. Any valid
//     stream must start with a literal insertion if no preset dictionary is
//     used.
//
//   - Backward copies: runs of one or more symbols are copied from previously
//     emitted data. A backward copy is a (dist, length) pair where dist says
//     how far back in the stream to copy from and length says how many bytes
//     to copy. The length may exceed the distance; because the copy proceeds
//     forwards, that acts as run-length encoding of a repeated pattern.
//     writeCopy and tryWriteCopy implement this command.
//
// For performance reasons this implementation performs little to no sanity
// checking of its arguments, so the invariants documented on each method
// must be respected by the caller.
type dictDecoder struct {
	hist []byte // Sliding window history

	// Invariant: 0 <= rdPos <= wrPos <= len(hist)
	wrPos int  // Current output position in buffer
	rdPos int  // Have emitted hist[:rdPos] already
	full  bool // Has a full window length been written yet?
}

// init prepares dd with a sliding window of the given size, reusing the
// existing backing array when it is large enough. If a preset dict is given,
// the window is seeded with it; only the last size bytes of dict are kept.
func (dd *dictDecoder) init(size int, dict []byte) {
	window := dd.hist
	if cap(window) < size {
		window = make([]byte, size)
	}
	window = window[:size]

	// Keep only the tail of the preset dictionary that fits in the window.
	if extra := len(dict) - size; extra > 0 {
		dict = dict[extra:]
	}
	written := copy(window, dict)

	// A dictionary that fills the whole window wraps the write position back
	// to the start and marks the window as fully populated.
	wrapped := written == size
	if wrapped {
		written = 0
	}
	*dd = dictDecoder{hist: window, wrPos: written, rdPos: written, full: wrapped}
}

// histSize reports the total amount of historical data in the dictionary.
func (dd *dictDecoder) histSize() int {
	if !dd.full {
		return dd.wrPos
	}
	return len(dd.hist)
}

// availRead reports the number of bytes that can be flushed by readFlush.
func (dd *dictDecoder) availRead() int {
	unread := dd.wrPos - dd.rdPos
	return unread
}

// availWrite reports the available amount of output buffer space.
func (dd *dictDecoder) availWrite() int {
	free := len(dd.hist) - dd.wrPos
	return free
}

// writeSlice returns the unwritten tail of the window for the caller to fill.
//
// This invariant will be kept: len(s) <= availWrite()
func (dd *dictDecoder) writeSlice() []byte {
	tail := dd.hist[dd.wrPos:]
	return tail
}

// writeMark advances the write position by cnt.
//
// This invariant must be kept: 0 <= cnt <= availWrite()
func (dd *dictDecoder) writeMark(cnt int) {
	dd.wrPos = dd.wrPos + cnt
}

// writeByte appends a single byte to the dictionary.
//
// This invariant must be kept: 0 < availWrite()
func (dd *dictDecoder) writeByte(c byte) {
	at := dd.wrPos
	dd.hist[at] = c
	dd.wrPos = at + 1
}

// writeCopy copies the string at the given (dist, length) to the output and
// returns the number of bytes copied, which may be less than length if the
// remaining space in the window is too small.
//
// This invariant must be kept: 0 < dist <= histSize()
func (dd *dictDecoder) writeCopy(dist, length int) int {
	window := dd.hist
	start := dd.wrPos
	limit := start + length
	if limit > len(window) {
		limit = len(window)
	}
	dst, src := start, start-dist

	// Source lies beyond the start of the buffer: the distance reaches back
	// into data that wrapped around, so it lives at the end of the window.
	// The bytes copied here never overlap the destination, so this is a
	// plain copy of the source as it was before the call.
	if src < 0 {
		dst += copy(window[dst:limit], window[src+len(window):])
		src = 0
	}

	// Source lies before the destination and may overlap it when the
	// remaining length exceeds the distance. LZ77 allows that so repeated
	// strings can be expressed compactly, which means bytes written earlier
	// in this loop may be read again later in it. Equivalent to:
	//
	//	for i := 0; i < limit-dst; i++ {
	//		window[dst+i] = window[src+i]
	//	}
	//	dst = limit
	//
	for dst < limit {
		dst += copy(window[dst:limit], window[src:dst])
	}

	dd.wrPos = dst
	return dst - start
}

// tryWriteCopy attempts to copy the string at the given (dist, length) to
// the output. It is a specialised version of writeCopy for the case where
// neither the source nor the destination touches the window boundary; when
// that does not hold it copies nothing and returns 0.
//
// This method is designed to be inlined for performance reasons.
//
// This invariant must be kept: 0 < dist <= histSize()
func (dd *dictDecoder) tryWriteCopy(dist, length int) int {
	start := dd.wrPos
	limit := start + length
	if start < dist || limit > len(dd.hist) {
		return 0
	}
	dst, src := start, start-dist

	// Copy the possibly overlapping section before the destination position.
copyMore:
	dst += copy(dd.hist[dst:limit], dd.hist[src:dst])
	if dst < limit {
		goto copyMore // Avoid for-loop so that this function can be inlined
	}

	dd.wrPos = dst
	return dst - start
}

// readFlush returns the part of the window that has been written but not yet
// handed to the user. The returned data must be fully consumed before any
// other dictDecoder method is called.
func (dd *dictDecoder) readFlush() []byte {
	pending := dd.hist[dd.rdPos:dd.wrPos]
	if dd.wrPos == len(dd.hist) {
		// The window is full: wrap both positions back to the start.
		dd.wrPos, dd.rdPos = 0, 0
		dd.full = true
	} else {
		dd.rdPos = dd.wrPos
	}
	return pending
}
|
|
@ -0,0 +1,265 @@
|
|||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build ignore
|
||||
|
||||
// This program generates fixedhuff.go
|
||||
// Invoke as
|
||||
//
|
||||
// go run gen.go -output fixedhuff.go
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
)
|
||||
|
||||
var filename = flag.String("output", "fixedhuff.go", "output file name")
|
||||
|
||||
const maxCodeLen = 16
|
||||
|
||||
// Note: the definition of the huffmanDecoder struct is copied from
|
||||
// inflate.go, as it is private to the implementation.
|
||||
|
||||
// chunk & 15 is number of bits
|
||||
// chunk >> 4 is value, including table link
|
||||
|
||||
const (
|
||||
huffmanChunkBits = 9
|
||||
huffmanNumChunks = 1 << huffmanChunkBits
|
||||
huffmanCountMask = 15
|
||||
huffmanValueShift = 4
|
||||
)
|
||||
|
||||
type huffmanDecoder struct {
|
||||
min int // the minimum code length
|
||||
chunks [huffmanNumChunks]uint32 // chunks as described above
|
||||
links [][]uint32 // overflow links
|
||||
linkMask uint32 // mask the width of the link table
|
||||
}
|
||||
|
||||
// Initialize Huffman decoding tables from array of code lengths.
|
||||
// Following this function, h is guaranteed to be initialized into a complete
|
||||
// tree (i.e., neither over-subscribed nor under-subscribed). The exception is a
|
||||
// degenerate case where the tree has only a single symbol with length 1. Empty
|
||||
// trees are permitted.
|
||||
func (h *huffmanDecoder) init(bits []int) bool {
|
||||
// Sanity enables additional runtime tests during Huffman
|
||||
// table construction. It's intended to be used during
|
||||
// development to supplement the currently ad-hoc unit tests.
|
||||
const sanity = false
|
||||
|
||||
if h.min != 0 {
|
||||
*h = huffmanDecoder{}
|
||||
}
|
||||
|
||||
// Count number of codes of each length,
|
||||
// compute min and max length.
|
||||
var count [maxCodeLen]int
|
||||
var min, max int
|
||||
for _, n := range bits {
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
if min == 0 || n < min {
|
||||
min = n
|
||||
}
|
||||
if n > max {
|
||||
max = n
|
||||
}
|
||||
count[n]++
|
||||
}
|
||||
|
||||
// Empty tree. The decompressor.huffSym function will fail later if the tree
|
||||
// is used. Technically, an empty tree is only valid for the HDIST tree and
|
||||
// not the HCLEN and HLIT tree. However, a stream with an empty HCLEN tree
|
||||
// is guaranteed to fail since it will attempt to use the tree to decode the
|
||||
// codes for the HLIT and HDIST trees. Similarly, an empty HLIT tree is
|
||||
// guaranteed to fail later since the compressed data section must be
|
||||
// composed of at least one symbol (the end-of-block marker).
|
||||
if max == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
code := 0
|
||||
var nextcode [maxCodeLen]int
|
||||
for i := min; i <= max; i++ {
|
||||
code <<= 1
|
||||
nextcode[i] = code
|
||||
code += count[i]
|
||||
}
|
||||
|
||||
// Check that the coding is complete (i.e., that we've
|
||||
// assigned all 2-to-the-max possible bit sequences).
|
||||
// Exception: To be compatible with zlib, we also need to
|
||||
// accept degenerate single-code codings. See also
|
||||
// TestDegenerateHuffmanCoding.
|
||||
if code != 1<<uint(max) && !(code == 1 && max == 1) {
|
||||
return false
|
||||
}
|
||||
|
||||
h.min = min
|
||||
if max > huffmanChunkBits {
|
||||
numLinks := 1 << (uint(max) - huffmanChunkBits)
|
||||
h.linkMask = uint32(numLinks - 1)
|
||||
|
||||
// create link tables
|
||||
link := nextcode[huffmanChunkBits+1] >> 1
|
||||
h.links = make([][]uint32, huffmanNumChunks-link)
|
||||
for j := uint(link); j < huffmanNumChunks; j++ {
|
||||
reverse := int(reverseByte[j>>8]) | int(reverseByte[j&0xff])<<8
|
||||
reverse >>= uint(16 - huffmanChunkBits)
|
||||
off := j - uint(link)
|
||||
if sanity && h.chunks[reverse] != 0 {
|
||||
panic("impossible: overwriting existing chunk")
|
||||
}
|
||||
h.chunks[reverse] = uint32(off<<huffmanValueShift | (huffmanChunkBits + 1))
|
||||
h.links[off] = make([]uint32, numLinks)
|
||||
}
|
||||
}
|
||||
|
||||
for i, n := range bits {
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
code := nextcode[n]
|
||||
nextcode[n]++
|
||||
chunk := uint32(i<<huffmanValueShift | n)
|
||||
reverse := int(reverseByte[code>>8]) | int(reverseByte[code&0xff])<<8
|
||||
reverse >>= uint(16 - n)
|
||||
if n <= huffmanChunkBits {
|
||||
for off := reverse; off < len(h.chunks); off += 1 << uint(n) {
|
||||
// We should never need to overwrite
|
||||
// an existing chunk. Also, 0 is
|
||||
// never a valid chunk, because the
|
||||
// lower 4 "count" bits should be
|
||||
// between 1 and 15.
|
||||
if sanity && h.chunks[off] != 0 {
|
||||
panic("impossible: overwriting existing chunk")
|
||||
}
|
||||
h.chunks[off] = chunk
|
||||
}
|
||||
} else {
|
||||
j := reverse & (huffmanNumChunks - 1)
|
||||
if sanity && h.chunks[j]&huffmanCountMask != huffmanChunkBits+1 {
|
||||
// Longer codes should have been
|
||||
// associated with a link table above.
|
||||
panic("impossible: not an indirect chunk")
|
||||
}
|
||||
value := h.chunks[j] >> huffmanValueShift
|
||||
linktab := h.links[value]
|
||||
reverse >>= huffmanChunkBits
|
||||
for off := reverse; off < len(linktab); off += 1 << uint(n-huffmanChunkBits) {
|
||||
if sanity && linktab[off] != 0 {
|
||||
panic("impossible: overwriting existing chunk")
|
||||
}
|
||||
linktab[off] = chunk
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sanity {
|
||||
// Above we've sanity checked that we never overwrote
|
||||
// an existing entry. Here we additionally check that
|
||||
// we filled the tables completely.
|
||||
for i, chunk := range h.chunks {
|
||||
if chunk == 0 {
|
||||
// As an exception, in the degenerate
|
||||
// single-code case, we allow odd
|
||||
// chunks to be missing.
|
||||
if code == 1 && i%2 == 1 {
|
||||
continue
|
||||
}
|
||||
panic("impossible: missing chunk")
|
||||
}
|
||||
}
|
||||
for _, linktab := range h.links {
|
||||
for _, chunk := range linktab {
|
||||
if chunk == 0 {
|
||||
panic("impossible: missing chunk")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
var h huffmanDecoder
|
||||
var bits [288]int
|
||||
initReverseByte()
|
||||
for i := 0; i < 144; i++ {
|
||||
bits[i] = 8
|
||||
}
|
||||
for i := 144; i < 256; i++ {
|
||||
bits[i] = 9
|
||||
}
|
||||
for i := 256; i < 280; i++ {
|
||||
bits[i] = 7
|
||||
}
|
||||
for i := 280; i < 288; i++ {
|
||||
bits[i] = 8
|
||||
}
|
||||
h.init(bits[:])
|
||||
if h.links != nil {
|
||||
log.Fatal("Unexpected links table in fixed Huffman decoder")
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
fmt.Fprintf(&buf, `// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.`+"\n\n")
|
||||
|
||||
fmt.Fprintln(&buf, "package flate")
|
||||
fmt.Fprintln(&buf)
|
||||
fmt.Fprintln(&buf, "// autogenerated by go run gen.go -output fixedhuff.go, DO NOT EDIT")
|
||||
fmt.Fprintln(&buf)
|
||||
fmt.Fprintln(&buf, "var fixedHuffmanDecoder = huffmanDecoder{")
|
||||
fmt.Fprintf(&buf, "\t%d,\n", h.min)
|
||||
fmt.Fprintln(&buf, "\t[huffmanNumChunks]uint32{")
|
||||
for i := 0; i < huffmanNumChunks; i++ {
|
||||
if i&7 == 0 {
|
||||
fmt.Fprintf(&buf, "\t\t")
|
||||
} else {
|
||||
fmt.Fprintf(&buf, " ")
|
||||
}
|
||||
fmt.Fprintf(&buf, "0x%04x,", h.chunks[i])
|
||||
if i&7 == 7 {
|
||||
fmt.Fprintln(&buf)
|
||||
}
|
||||
}
|
||||
fmt.Fprintln(&buf, "\t},")
|
||||
fmt.Fprintln(&buf, "\tnil, 0,")
|
||||
fmt.Fprintln(&buf, "}")
|
||||
|
||||
data, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
err = ioutil.WriteFile(*filename, data, 0644)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// reverseByte maps each byte value to the value with its eight bits in
// reverse order. It is filled in by initReverseByte.
var reverseByte [256]byte

// initReverseByte populates reverseByte: for every byte value, bit i of the
// input becomes bit 7-i of the stored result.
func initReverseByte() {
	for v := range reverseByte {
		var mirrored byte
		for bit := uint(0); bit < 8; bit++ {
			if v&(1<<bit) != 0 {
				mirrored |= 0x80 >> bit
			}
		}
		reverseByte[v] = mirrored
	}
}
|
|
@ -0,0 +1,701 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package flate
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
// The largest offset code.
|
||||
offsetCodeCount = 30
|
||||
|
||||
// The special code used to mark the end of a block.
|
||||
endBlockMarker = 256
|
||||
|
||||
// The first length code.
|
||||
lengthCodesStart = 257
|
||||
|
||||
// The number of codegen codes.
|
||||
codegenCodeCount = 19
|
||||
badCode = 255
|
||||
|
||||
// bufferFlushSize indicates the buffer size
|
||||
// after which bytes are flushed to the writer.
|
||||
// Should preferably be a multiple of 6, since
|
||||
// we accumulate 6 bytes between writes to the buffer.
|
||||
bufferFlushSize = 240
|
||||
|
||||
// bufferSize is the actual output byte buffer size.
|
||||
// It must have additional headroom for a flush
|
||||
// which can contain up to 8 bytes.
|
||||
bufferSize = bufferFlushSize + 8
|
||||
)
|
||||
|
||||
// The number of extra bits needed by length code X - LENGTH_CODES_START.
|
||||
var lengthExtraBits = []int8{
|
||||
/* 257 */ 0, 0, 0,
|
||||
/* 260 */ 0, 0, 0, 0, 0, 1, 1, 1, 1, 2,
|
||||
/* 270 */ 2, 2, 2, 3, 3, 3, 3, 4, 4, 4,
|
||||
/* 280 */ 4, 5, 5, 5, 5, 0,
|
||||
}
|
||||
|
||||
// The length indicated by length code X - LENGTH_CODES_START.
|
||||
var lengthBase = []uint32{
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 10,
|
||||
12, 14, 16, 20, 24, 28, 32, 40, 48, 56,
|
||||
64, 80, 96, 112, 128, 160, 192, 224, 255,
|
||||
}
|
||||
|
||||
// offset code word extra bits.
|
||||
var offsetExtraBits = []int8{
|
||||
0, 0, 0, 0, 1, 1, 2, 2, 3, 3,
|
||||
4, 4, 5, 5, 6, 6, 7, 7, 8, 8,
|
||||
9, 9, 10, 10, 11, 11, 12, 12, 13, 13,
|
||||
/* extended window */
|
||||
14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20,
|
||||
}
|
||||
|
||||
var offsetBase = []uint32{
|
||||
/* normal deflate */
|
||||
0x000000, 0x000001, 0x000002, 0x000003, 0x000004,
|
||||
0x000006, 0x000008, 0x00000c, 0x000010, 0x000018,
|
||||
0x000020, 0x000030, 0x000040, 0x000060, 0x000080,
|
||||
0x0000c0, 0x000100, 0x000180, 0x000200, 0x000300,
|
||||
0x000400, 0x000600, 0x000800, 0x000c00, 0x001000,
|
||||
0x001800, 0x002000, 0x003000, 0x004000, 0x006000,
|
||||
|
||||
/* extended window */
|
||||
0x008000, 0x00c000, 0x010000, 0x018000, 0x020000,
|
||||
0x030000, 0x040000, 0x060000, 0x080000, 0x0c0000,
|
||||
0x100000, 0x180000, 0x200000, 0x300000,
|
||||
}
|
||||
|
||||
// The odd order in which the codegen code sizes are written.
|
||||
var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
|
||||
|
||||
type huffmanBitWriter struct {
|
||||
// writer is the underlying writer.
|
||||
// Do not use it directly; use the write method, which ensures
|
||||
// that Write errors are sticky.
|
||||
writer io.Writer
|
||||
|
||||
// Data waiting to be written is bytes[0:nbytes]
|
||||
// and then the low nbits of bits.
|
||||
bits uint64
|
||||
nbits uint
|
||||
bytes [bufferSize]byte
|
||||
codegenFreq [codegenCodeCount]int32
|
||||
nbytes int
|
||||
literalFreq []int32
|
||||
offsetFreq []int32
|
||||
codegen []uint8
|
||||
literalEncoding *huffmanEncoder
|
||||
offsetEncoding *huffmanEncoder
|
||||
codegenEncoding *huffmanEncoder
|
||||
err error
|
||||
}
|
||||
|
||||
func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
|
||||
return &huffmanBitWriter{
|
||||
writer: w,
|
||||
literalFreq: make([]int32, maxNumLit),
|
||||
offsetFreq: make([]int32, offsetCodeCount),
|
||||
codegen: make([]uint8, maxNumLit+offsetCodeCount+1),
|
||||
literalEncoding: newHuffmanEncoder(maxNumLit),
|
||||
codegenEncoding: newHuffmanEncoder(codegenCodeCount),
|
||||
offsetEncoding: newHuffmanEncoder(offsetCodeCount),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *huffmanBitWriter) reset(writer io.Writer) {
|
||||
w.writer = writer
|
||||
w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
|
||||
w.bytes = [bufferSize]byte{}
|
||||
}
|
||||
|
||||
func (w *huffmanBitWriter) flush() {
|
||||
if w.err != nil {
|
||||
w.nbits = 0
|
||||
return
|
||||
}
|
||||
n := w.nbytes
|
||||
for w.nbits != 0 {
|
||||
w.bytes[n] = byte(w.bits)
|
||||
w.bits >>= 8
|
||||
if w.nbits > 8 { // Avoid underflow
|
||||
w.nbits -= 8
|
||||
} else {
|
||||
w.nbits = 0
|
||||
}
|
||||
n++
|
||||
}
|
||||
w.bits = 0
|
||||
w.write(w.bytes[:n])
|
||||
w.nbytes = 0
|
||||
}
|
||||
|
||||
func (w *huffmanBitWriter) write(b []byte) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
_, w.err = w.writer.Write(b)
|
||||
}
|
||||
|
||||
func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
w.bits |= uint64(b) << w.nbits
|
||||
w.nbits += nb
|
||||
if w.nbits >= 48 {
|
||||
bits := w.bits
|
||||
w.bits >>= 48
|
||||
w.nbits -= 48
|
||||
n := w.nbytes
|
||||
bytes := w.bytes[n : n+6]
|
||||
bytes[0] = byte(bits)
|
||||
bytes[1] = byte(bits >> 8)
|
||||
bytes[2] = byte(bits >> 16)
|
||||
bytes[3] = byte(bits >> 24)
|
||||
bytes[4] = byte(bits >> 32)
|
||||
bytes[5] = byte(bits >> 40)
|
||||
n += 6
|
||||
if n >= bufferFlushSize {
|
||||
w.write(w.bytes[:n])
|
||||
n = 0
|
||||
}
|
||||
w.nbytes = n
|
||||
}
|
||||
}
|
||||
|
||||
func (w *huffmanBitWriter) writeBytes(bytes []byte) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
n := w.nbytes
|
||||
if w.nbits&7 != 0 {
|
||||
w.err = InternalError("writeBytes with unfinished bits")
|
||||
return
|
||||
}
|
||||
for w.nbits != 0 {
|
||||
w.bytes[n] = byte(w.bits)
|
||||
w.bits >>= 8
|
||||
w.nbits -= 8
|
||||
n++
|
||||
}
|
||||
if n != 0 {
|
||||
w.write(w.bytes[:n])
|
||||
}
|
||||
w.nbytes = 0
|
||||
w.write(bytes)
|
||||
}
|
||||
|
||||
// RFC 1951 3.2.7 specifies a special run-length encoding for specifying
|
||||
// the literal and offset lengths arrays (which are concatenated into a single
|
||||
// array). This method generates that run-length encoding.
|
||||
//
|
||||
// The result is written into the codegen array, and the frequencies
|
||||
// of each code are written into the codegenFreq array.
|
||||
// Codes 0-15 are single byte codes. Codes 16-18 are followed by additional
|
||||
// information. Code badCode is an end marker
|
||||
//
|
||||
// numLiterals The number of literals in literalEncoding
|
||||
// numOffsets The number of offsets in offsetEncoding
|
||||
// litEnc, offEnc The literal and offset encoders to use
|
||||
func (w *huffmanBitWriter) generateCodegen(numLiterals int, numOffsets int, litEnc, offEnc *huffmanEncoder) {
|
||||
for i := range w.codegenFreq {
|
||||
w.codegenFreq[i] = 0
|
||||
}
|
||||
// Note that we are using codegen both as a temporary variable for holding
|
||||
// a copy of the frequencies, and as the place where we put the result.
|
||||
// This is fine because the output is always shorter than the input used
|
||||
// so far.
|
||||
codegen := w.codegen // cache
|
||||
// Copy the concatenated code sizes to codegen. Put a marker at the end.
|
||||
cgnl := codegen[:numLiterals]
|
||||
for i := range cgnl {
|
||||
cgnl[i] = uint8(litEnc.codes[i].len)
|
||||
}
|
||||
|
||||
cgnl = codegen[numLiterals : numLiterals+numOffsets]
|
||||
for i := range cgnl {
|
||||
cgnl[i] = uint8(offEnc.codes[i].len)
|
||||
}
|
||||
codegen[numLiterals+numOffsets] = badCode
|
||||
|
||||
size := codegen[0]
|
||||
count := 1
|
||||
outIndex := 0
|
||||
for inIndex := 1; size != badCode; inIndex++ {
|
||||
// INVARIANT: We have seen "count" copies of size that have not yet
|
||||
// had output generated for them.
|
||||
nextSize := codegen[inIndex]
|
||||
if nextSize == size {
|
||||
count++
|
||||
continue
|
||||
}
|
||||
// We need to generate codegen indicating "count" of size.
|
||||
if size != 0 {
|
||||
codegen[outIndex] = size
|
||||
outIndex++
|
||||
w.codegenFreq[size]++
|
||||
count--
|
||||
for count >= 3 {
|
||||
n := 6
|
||||
if n > count {
|
||||
n = count
|
||||
}
|
||||
codegen[outIndex] = 16
|
||||
outIndex++
|
||||
codegen[outIndex] = uint8(n - 3)
|
||||
outIndex++
|
||||
w.codegenFreq[16]++
|
||||
count -= n
|
||||
}
|
||||
} else {
|
||||
for count >= 11 {
|
||||
n := 138
|
||||
if n > count {
|
||||
n = count
|
||||
}
|
||||
codegen[outIndex] = 18
|
||||
outIndex++
|
||||
codegen[outIndex] = uint8(n - 11)
|
||||
outIndex++
|
||||
w.codegenFreq[18]++
|
||||
count -= n
|
||||
}
|
||||
if count >= 3 {
|
||||
// count >= 3 && count <= 10
|
||||
codegen[outIndex] = 17
|
||||
outIndex++
|
||||
codegen[outIndex] = uint8(count - 3)
|
||||
outIndex++
|
||||
w.codegenFreq[17]++
|
||||
count = 0
|
||||
}
|
||||
}
|
||||
count--
|
||||
for ; count >= 0; count-- {
|
||||
codegen[outIndex] = size
|
||||
outIndex++
|
||||
w.codegenFreq[size]++
|
||||
}
|
||||
// Set up invariant for next time through the loop.
|
||||
size = nextSize
|
||||
count = 1
|
||||
}
|
||||
// Marker indicating the end of the codegen.
|
||||
codegen[outIndex] = badCode
|
||||
}
|
||||
|
||||
// dynamicSize returns the size of dynamically encoded data in bits.
|
||||
func (w *huffmanBitWriter) dynamicSize(litEnc, offEnc *huffmanEncoder, extraBits int) (size, numCodegens int) {
|
||||
numCodegens = len(w.codegenFreq)
|
||||
for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 {
|
||||
numCodegens--
|
||||
}
|
||||
header := 3 + 5 + 5 + 4 + (3 * numCodegens) +
|
||||
w.codegenEncoding.bitLength(w.codegenFreq[:]) +
|
||||
int(w.codegenFreq[16])*2 +
|
||||
int(w.codegenFreq[17])*3 +
|
||||
int(w.codegenFreq[18])*7
|
||||
size = header +
|
||||
litEnc.bitLength(w.literalFreq) +
|
||||
offEnc.bitLength(w.offsetFreq) +
|
||||
extraBits
|
||||
|
||||
return size, numCodegens
|
||||
}
|
||||
|
||||
// fixedSize returns the size of dynamically encoded data in bits.
|
||||
func (w *huffmanBitWriter) fixedSize(extraBits int) int {
|
||||
return 3 +
|
||||
fixedLiteralEncoding.bitLength(w.literalFreq) +
|
||||
fixedOffsetEncoding.bitLength(w.offsetFreq) +
|
||||
extraBits
|
||||
}
|
||||
|
||||
// storedSize calculates the stored size, including header.
|
||||
// The function returns the size in bits and whether the block
|
||||
// fits inside a single block.
|
||||
func (w *huffmanBitWriter) storedSize(in []byte) (int, bool) {
|
||||
if in == nil {
|
||||
return 0, false
|
||||
}
|
||||
if len(in) <= maxStoreBlockSize {
|
||||
return (len(in) + 5) * 8, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (w *huffmanBitWriter) writeCode(c hcode) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
w.bits |= uint64(c.code) << w.nbits
|
||||
w.nbits += uint(c.len)
|
||||
if w.nbits >= 48 {
|
||||
bits := w.bits
|
||||
w.bits >>= 48
|
||||
w.nbits -= 48
|
||||
n := w.nbytes
|
||||
bytes := w.bytes[n : n+6]
|
||||
bytes[0] = byte(bits)
|
||||
bytes[1] = byte(bits >> 8)
|
||||
bytes[2] = byte(bits >> 16)
|
||||
bytes[3] = byte(bits >> 24)
|
||||
bytes[4] = byte(bits >> 32)
|
||||
bytes[5] = byte(bits >> 40)
|
||||
n += 6
|
||||
if n >= bufferFlushSize {
|
||||
w.write(w.bytes[:n])
|
||||
n = 0
|
||||
}
|
||||
w.nbytes = n
|
||||
}
|
||||
}
|
||||
|
||||
// Write the header of a dynamic Huffman block to the output stream.
|
||||
//
|
||||
// numLiterals The number of literals specified in codegen
|
||||
// numOffsets The number of offsets specified in codegen
|
||||
// numCodegens The number of codegens used in codegen
|
||||
func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
var firstBits int32 = 4
|
||||
if isEof {
|
||||
firstBits = 5
|
||||
}
|
||||
w.writeBits(firstBits, 3)
|
||||
w.writeBits(int32(numLiterals-257), 5)
|
||||
w.writeBits(int32(numOffsets-1), 5)
|
||||
w.writeBits(int32(numCodegens-4), 4)
|
||||
|
||||
for i := 0; i < numCodegens; i++ {
|
||||
value := uint(w.codegenEncoding.codes[codegenOrder[i]].len)
|
||||
w.writeBits(int32(value), 3)
|
||||
}
|
||||
|
||||
i := 0
|
||||
for {
|
||||
var codeWord int = int(w.codegen[i])
|
||||
i++
|
||||
if codeWord == badCode {
|
||||
break
|
||||
}
|
||||
w.writeCode(w.codegenEncoding.codes[uint32(codeWord)])
|
||||
|
||||
switch codeWord {
|
||||
case 16:
|
||||
w.writeBits(int32(w.codegen[i]), 2)
|
||||
i++
|
||||
break
|
||||
case 17:
|
||||
w.writeBits(int32(w.codegen[i]), 3)
|
||||
i++
|
||||
break
|
||||
case 18:
|
||||
w.writeBits(int32(w.codegen[i]), 7)
|
||||
i++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *huffmanBitWriter) writeStoredHeader(length int, isEof bool) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
var flag int32
|
||||
if isEof {
|
||||
flag = 1
|
||||
}
|
||||
w.writeBits(flag, 3)
|
||||
w.flush()
|
||||
w.writeBits(int32(length), 16)
|
||||
w.writeBits(int32(^uint16(length)), 16)
|
||||
}
|
||||
|
||||
func (w *huffmanBitWriter) writeFixedHeader(isEof bool) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
// Indicate that we are a fixed Huffman block
|
||||
var value int32 = 2
|
||||
if isEof {
|
||||
value = 3
|
||||
}
|
||||
w.writeBits(value, 3)
|
||||
}
|
||||
|
||||
// writeBlock will write a block of tokens with the smallest encoding.
|
||||
// The original input can be supplied, and if the huffman encoded data
|
||||
// is larger than the original bytes, the data will be written as a
|
||||
// stored block.
|
||||
// If the input is nil, the tokens will always be Huffman encoded.
|
||||
func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tokens = append(tokens, endBlockMarker)
|
||||
numLiterals, numOffsets := w.indexTokens(tokens)
|
||||
|
||||
var extraBits int
|
||||
storedSize, storable := w.storedSize(input)
|
||||
if storable {
|
||||
// We only bother calculating the costs of the extra bits required by
|
||||
// the length of offset fields (which will be the same for both fixed
|
||||
// and dynamic encoding), if we need to compare those two encodings
|
||||
// against stored encoding.
|
||||
for lengthCode := lengthCodesStart + 8; lengthCode < numLiterals; lengthCode++ {
|
||||
// First eight length codes have extra size = 0.
|
||||
extraBits += int(w.literalFreq[lengthCode]) * int(lengthExtraBits[lengthCode-lengthCodesStart])
|
||||
}
|
||||
for offsetCode := 4; offsetCode < numOffsets; offsetCode++ {
|
||||
// First four offset codes have extra size = 0.
|
||||
extraBits += int(w.offsetFreq[offsetCode]) * int(offsetExtraBits[offsetCode])
|
||||
}
|
||||
}
|
||||
|
||||
// Figure out smallest code.
|
||||
// Fixed Huffman baseline.
|
||||
var literalEncoding = fixedLiteralEncoding
|
||||
var offsetEncoding = fixedOffsetEncoding
|
||||
var size = w.fixedSize(extraBits)
|
||||
|
||||
// Dynamic Huffman?
|
||||
var numCodegens int
|
||||
|
||||
// Generate codegen and codegenFrequencies, which indicates how to encode
|
||||
// the literalEncoding and the offsetEncoding.
|
||||
w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding)
|
||||
w.codegenEncoding.generate(w.codegenFreq[:], 7)
|
||||
dynamicSize, numCodegens := w.dynamicSize(w.literalEncoding, w.offsetEncoding, extraBits)
|
||||
|
||||
if dynamicSize < size {
|
||||
size = dynamicSize
|
||||
literalEncoding = w.literalEncoding
|
||||
offsetEncoding = w.offsetEncoding
|
||||
}
|
||||
|
||||
// Stored bytes?
|
||||
if storable && storedSize < size {
|
||||
w.writeStoredHeader(len(input), eof)
|
||||
w.writeBytes(input)
|
||||
return
|
||||
}
|
||||
|
||||
// Huffman.
|
||||
if literalEncoding == fixedLiteralEncoding {
|
||||
w.writeFixedHeader(eof)
|
||||
} else {
|
||||
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
|
||||
}
|
||||
|
||||
// Write the tokens.
|
||||
w.writeTokens(tokens, literalEncoding.codes, offsetEncoding.codes)
|
||||
}
|
||||
|
||||
// writeBlockDynamic encodes a block using a dynamic Huffman table.
|
||||
// This should be used if the symbols used have a disproportionate
|
||||
// histogram distribution.
|
||||
// If input is supplied and the compression savings are below 1/16th of the
|
||||
// input size the block is stored.
|
||||
func (w *huffmanBitWriter) writeBlockDynamic(tokens []token, eof bool, input []byte) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tokens = append(tokens, endBlockMarker)
|
||||
numLiterals, numOffsets := w.indexTokens(tokens)
|
||||
|
||||
// Generate codegen and codegenFrequencies, which indicates how to encode
|
||||
// the literalEncoding and the offsetEncoding.
|
||||
w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding)
|
||||
w.codegenEncoding.generate(w.codegenFreq[:], 7)
|
||||
size, numCodegens := w.dynamicSize(w.literalEncoding, w.offsetEncoding, 0)
|
||||
|
||||
// Store bytes, if we don't get a reasonable improvement.
|
||||
if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) {
|
||||
w.writeStoredHeader(len(input), eof)
|
||||
w.writeBytes(input)
|
||||
return
|
||||
}
|
||||
|
||||
// Write Huffman table.
|
||||
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
|
||||
|
||||
// Write the tokens.
|
||||
w.writeTokens(tokens, w.literalEncoding.codes, w.offsetEncoding.codes)
|
||||
}
|
||||
|
||||
// indexTokens indexes a slice of tokens, and updates
|
||||
// literalFreq and offsetFreq, and generates literalEncoding
|
||||
// and offsetEncoding.
|
||||
// The number of literal and offset tokens is returned.
|
||||
func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets int) {
|
||||
for i := range w.literalFreq {
|
||||
w.literalFreq[i] = 0
|
||||
}
|
||||
for i := range w.offsetFreq {
|
||||
w.offsetFreq[i] = 0
|
||||
}
|
||||
|
||||
for _, t := range tokens {
|
||||
if t < matchType {
|
||||
w.literalFreq[t.literal()]++
|
||||
continue
|
||||
}
|
||||
length := t.length()
|
||||
offset := t.offset()
|
||||
w.literalFreq[lengthCodesStart+lengthCode(length)]++
|
||||
w.offsetFreq[offsetCode(offset)]++
|
||||
}
|
||||
|
||||
// get the number of literals
|
||||
numLiterals = len(w.literalFreq)
|
||||
for w.literalFreq[numLiterals-1] == 0 {
|
||||
numLiterals--
|
||||
}
|
||||
// get the number of offsets
|
||||
numOffsets = len(w.offsetFreq)
|
||||
for numOffsets > 0 && w.offsetFreq[numOffsets-1] == 0 {
|
||||
numOffsets--
|
||||
}
|
||||
if numOffsets == 0 {
|
||||
// We haven't found a single match. If we want to go with the dynamic encoding,
|
||||
// we should count at least one offset to be sure that the offset huffman tree could be encoded.
|
||||
w.offsetFreq[0] = 1
|
||||
numOffsets = 1
|
||||
}
|
||||
w.literalEncoding.generate(w.literalFreq, 15)
|
||||
w.offsetEncoding.generate(w.offsetFreq, 15)
|
||||
return
|
||||
}
|
||||
|
||||
// writeTokens writes a slice of tokens to the output.
|
||||
// codes for literal and offset encoding must be supplied.
|
||||
func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
for _, t := range tokens {
|
||||
if t < matchType {
|
||||
w.writeCode(leCodes[t.literal()])
|
||||
continue
|
||||
}
|
||||
// Write the length
|
||||
length := t.length()
|
||||
lengthCode := lengthCode(length)
|
||||
w.writeCode(leCodes[lengthCode+lengthCodesStart])
|
||||
extraLengthBits := uint(lengthExtraBits[lengthCode])
|
||||
if extraLengthBits > 0 {
|
||||
extraLength := int32(length - lengthBase[lengthCode])
|
||||
w.writeBits(extraLength, extraLengthBits)
|
||||
}
|
||||
// Write the offset
|
||||
offset := t.offset()
|
||||
offsetCode := offsetCode(offset)
|
||||
w.writeCode(oeCodes[offsetCode])
|
||||
extraOffsetBits := uint(offsetExtraBits[offsetCode])
|
||||
if extraOffsetBits > 0 {
|
||||
extraOffset := int32(offset - offsetBase[offsetCode])
|
||||
w.writeBits(extraOffset, extraOffsetBits)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// huffOffset is a static offset encoder used for huffman only encoding.
|
||||
// It can be reused since we will not be encoding offset values.
|
||||
var huffOffset *huffmanEncoder
|
||||
|
||||
func init() {
|
||||
w := newHuffmanBitWriter(nil)
|
||||
w.offsetFreq[0] = 1
|
||||
huffOffset = newHuffmanEncoder(offsetCodeCount)
|
||||
huffOffset.generate(w.offsetFreq, 15)
|
||||
}
|
||||
|
||||
// writeBlockHuff encodes a block of bytes as either
|
||||
// Huffman encoded literals or uncompressed bytes if the
|
||||
// results only gains very little from compression.
|
||||
func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear histogram
|
||||
for i := range w.literalFreq {
|
||||
w.literalFreq[i] = 0
|
||||
}
|
||||
|
||||
// Add everything as literals
|
||||
histogram(input, w.literalFreq)
|
||||
|
||||
w.literalFreq[endBlockMarker] = 1
|
||||
|
||||
const numLiterals = endBlockMarker + 1
|
||||
const numOffsets = 1
|
||||
|
||||
w.literalEncoding.generate(w.literalFreq, 15)
|
||||
|
||||
// Figure out smallest code.
|
||||
// Always use dynamic Huffman or Store
|
||||
var numCodegens int
|
||||
|
||||
// Generate codegen and codegenFrequencies, which indicates how to encode
|
||||
// the literalEncoding and the offsetEncoding.
|
||||
w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, huffOffset)
|
||||
w.codegenEncoding.generate(w.codegenFreq[:], 7)
|
||||
size, numCodegens := w.dynamicSize(w.literalEncoding, huffOffset, 0)
|
||||
|
||||
// Store bytes, if we don't get a reasonable improvement.
|
||||
if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) {
|
||||
w.writeStoredHeader(len(input), eof)
|
||||
w.writeBytes(input)
|
||||
return
|
||||
}
|
||||
|
||||
// Huffman.
|
||||
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
|
||||
encoding := w.literalEncoding.codes[:257]
|
||||
n := w.nbytes
|
||||
for _, t := range input {
|
||||
// Bitwriting inlined, ~30% speedup
|
||||
c := encoding[t]
|
||||
w.bits |= uint64(c.code) << w.nbits
|
||||
w.nbits += uint(c.len)
|
||||
if w.nbits < 48 {
|
||||
continue
|
||||
}
|
||||
// Store 6 bytes
|
||||
bits := w.bits
|
||||
w.bits >>= 48
|
||||
w.nbits -= 48
|
||||
bytes := w.bytes[n : n+6]
|
||||
bytes[0] = byte(bits)
|
||||
bytes[1] = byte(bits >> 8)
|
||||
bytes[2] = byte(bits >> 16)
|
||||
bytes[3] = byte(bits >> 24)
|
||||
bytes[4] = byte(bits >> 32)
|
||||
bytes[5] = byte(bits >> 40)
|
||||
n += 6
|
||||
if n < bufferFlushSize {
|
||||
continue
|
||||
}
|
||||
w.write(w.bytes[:n])
|
||||
if w.err != nil {
|
||||
return // Return early in the event of write failures
|
||||
}
|
||||
n = 0
|
||||
}
|
||||
w.nbytes = n
|
||||
w.writeCode(encoding[endBlockMarker])
|
||||
}
|
|
@ -0,0 +1,344 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package flate
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// hcode is a huffman code with a bit code and bit length.
type hcode struct {
	code uint16 // the bits of the code
	len  uint16 // the number of bits in code
}
|
||||
|
||||
type huffmanEncoder struct {
|
||||
codes []hcode
|
||||
freqcache []literalNode
|
||||
bitCount [17]int32
|
||||
lns byLiteral // stored to avoid repeated allocation in generate
|
||||
lfs byFreq // stored to avoid repeated allocation in generate
|
||||
}
|
||||
|
||||
// literalNode pairs a symbol with the frequency recorded for it.
type literalNode struct {
	literal uint16 // the symbol
	freq    int32  // how often the symbol occurs
}
|
||||
|
||||
// A levelInfo describes the state of the constructed tree for a given depth.
type levelInfo struct {
	// level is this level's own depth; kept for better printing.
	level int32

	// lastFreq is the frequency of the last node at this level.
	lastFreq int32

	// nextCharFreq is the frequency of the next character to add to this
	// level.
	nextCharFreq int32

	// nextPairFreq is the frequency of the next pair (from the level below)
	// to add to this level. It is only valid if the "needed" value of the
	// next lower level is 0.
	nextPairFreq int32

	// needed is the number of chains remaining to generate for this level
	// before moving up to the next level.
	needed int32
}
|
||||
|
||||
// set sets the code and length of an hcode.
|
||||
func (h *hcode) set(code uint16, length uint16) {
|
||||
h.len = length
|
||||
h.code = code
|
||||
}
|
||||
|
||||
// maxNode returns a literalNode whose literal and freq are both the largest
// value their types can hold.
func maxNode() literalNode {
	return literalNode{literal: math.MaxUint16, freq: math.MaxInt32}
}
|
||||
|
||||
func newHuffmanEncoder(size int) *huffmanEncoder {
|
||||
return &huffmanEncoder{codes: make([]hcode, size)}
|
||||
}
|
||||
|
||||
// Generates a HuffmanCode corresponding to the fixed literal table
|
||||
func generateFixedLiteralEncoding() *huffmanEncoder {
|
||||
h := newHuffmanEncoder(maxNumLit)
|
||||
codes := h.codes
|
||||
var ch uint16
|
||||
for ch = 0; ch < maxNumLit; ch++ {
|
||||
var bits uint16
|
||||
var size uint16
|
||||
switch {
|
||||
case ch < 144:
|
||||
// size 8, 000110000 .. 10111111
|
||||
bits = ch + 48
|
||||
size = 8
|
||||
break
|
||||
case ch < 256:
|
||||
// size 9, 110010000 .. 111111111
|
||||
bits = ch + 400 - 144
|
||||
size = 9
|
||||
break
|
||||
case ch < 280:
|
||||
// size 7, 0000000 .. 0010111
|
||||
bits = ch - 256
|
||||
size = 7
|
||||
break
|
||||
default:
|
||||
// size 8, 11000000 .. 11000111
|
||||
bits = ch + 192 - 280
|
||||
size = 8
|
||||
}
|
||||
codes[ch] = hcode{code: reverseBits(bits, byte(size)), len: size}
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func generateFixedOffsetEncoding() *huffmanEncoder {
|
||||
h := newHuffmanEncoder(30)
|
||||
codes := h.codes
|
||||
for ch := range codes {
|
||||
codes[ch] = hcode{code: reverseBits(uint16(ch), 5), len: 5}
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
var fixedLiteralEncoding *huffmanEncoder = generateFixedLiteralEncoding()
|
||||
var fixedOffsetEncoding *huffmanEncoder = generateFixedOffsetEncoding()
|
||||
|
||||
func (h *huffmanEncoder) bitLength(freq []int32) int {
|
||||
var total int
|
||||
for i, f := range freq {
|
||||
if f != 0 {
|
||||
total += int(f) * int(h.codes[i].len)
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
const maxBitsLimit = 16
|
||||
|
||||
// Return the number of literals assigned to each bit size in the Huffman encoding
|
||||
//
|
||||
// This method is only called when list.length >= 3
|
||||
// The cases of 0, 1, and 2 literals are handled by special case code.
|
||||
//
|
||||
// list An array of the literals with non-zero frequencies
|
||||
// and their associated frequencies. The array is in order of increasing
|
||||
// frequency, and has as its last element a special element with frequency
|
||||
// MaxInt32
|
||||
// maxBits The maximum number of bits that should be used to encode any literal.
|
||||
// Must be less than 16.
|
||||
// return An integer array in which array[i] indicates the number of literals
|
||||
// that should be encoded in i bits.
|
||||
func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 {
|
||||
if maxBits >= maxBitsLimit {
|
||||
panic("flate: maxBits too large")
|
||||
}
|
||||
n := int32(len(list))
|
||||
list = list[0 : n+1]
|
||||
list[n] = maxNode()
|
||||
|
||||
// The tree can't have greater depth than n - 1, no matter what. This
|
||||
// saves a little bit of work in some small cases
|
||||
if maxBits > n-1 {
|
||||
maxBits = n - 1
|
||||
}
|
||||
|
||||
// Create information about each of the levels.
|
||||
// A bogus "Level 0" whose sole purpose is so that
|
||||
// level1.prev.needed==0. This makes level1.nextPairFreq
|
||||
// be a legitimate value that never gets chosen.
|
||||
var levels [maxBitsLimit]levelInfo
|
||||
// leafCounts[i] counts the number of literals at the left
|
||||
// of ancestors of the rightmost node at level i.
|
||||
// leafCounts[i][j] is the number of literals at the left
|
||||
// of the level j ancestor.
|
||||
var leafCounts [maxBitsLimit][maxBitsLimit]int32
|
||||
|
||||
for level := int32(1); level <= maxBits; level++ {
|
||||
// For every level, the first two items are the first two characters.
|
||||
// We initialize the levels as if we had already figured this out.
|
||||
levels[level] = levelInfo{
|
||||
level: level,
|
||||
lastFreq: list[1].freq,
|
||||
nextCharFreq: list[2].freq,
|
||||
nextPairFreq: list[0].freq + list[1].freq,
|
||||
}
|
||||
leafCounts[level][level] = 2
|
||||
if level == 1 {
|
||||
levels[level].nextPairFreq = math.MaxInt32
|
||||
}
|
||||
}
|
||||
|
||||
// We need a total of 2*n - 2 items at top level and have already generated 2.
|
||||
levels[maxBits].needed = 2*n - 4
|
||||
|
||||
level := maxBits
|
||||
for {
|
||||
l := &levels[level]
|
||||
if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 {
|
||||
// We've run out of both leafs and pairs.
|
||||
// End all calculations for this level.
|
||||
// To make sure we never come back to this level or any lower level,
|
||||
// set nextPairFreq impossibly large.
|
||||
l.needed = 0
|
||||
levels[level+1].nextPairFreq = math.MaxInt32
|
||||
level++
|
||||
continue
|
||||
}
|
||||
|
||||
prevFreq := l.lastFreq
|
||||
if l.nextCharFreq < l.nextPairFreq {
|
||||
// The next item on this row is a leaf node.
|
||||
n := leafCounts[level][level] + 1
|
||||
l.lastFreq = l.nextCharFreq
|
||||
// Lower leafCounts are the same of the previous node.
|
||||
leafCounts[level][level] = n
|
||||
l.nextCharFreq = list[n].freq
|
||||
} else {
|
||||
// The next item on this row is a pair from the previous row.
|
||||
// nextPairFreq isn't valid until we generate two
|
||||
// more values in the level below
|
||||
l.lastFreq = l.nextPairFreq
|
||||
// Take leaf counts from the lower level, except counts[level] remains the same.
|
||||
copy(leafCounts[level][:level], leafCounts[level-1][:level])
|
||||
levels[l.level-1].needed = 2
|
||||
}
|
||||
|
||||
if l.needed--; l.needed == 0 {
|
||||
// We've done everything we need to do for this level.
|
||||
// Continue calculating one level up. Fill in nextPairFreq
|
||||
// of that level with the sum of the two nodes we've just calculated on
|
||||
// this level.
|
||||
if l.level == maxBits {
|
||||
// All done!
|
||||
break
|
||||
}
|
||||
levels[l.level+1].nextPairFreq = prevFreq + l.lastFreq
|
||||
level++
|
||||
} else {
|
||||
// If we stole from below, move down temporarily to replenish it.
|
||||
for levels[level-1].needed > 0 {
|
||||
level--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Somethings is wrong if at the end, the top level is null or hasn't used
|
||||
// all of the leaves.
|
||||
if leafCounts[maxBits][maxBits] != n {
|
||||
panic("leafCounts[maxBits][maxBits] != n")
|
||||
}
|
||||
|
||||
bitCount := h.bitCount[:maxBits+1]
|
||||
bits := 1
|
||||
counts := &leafCounts[maxBits]
|
||||
for level := maxBits; level > 0; level-- {
|
||||
// chain.leafCount gives the number of literals requiring at least "bits"
|
||||
// bits to encode.
|
||||
bitCount[bits] = counts[level] - counts[level-1]
|
||||
bits++
|
||||
}
|
||||
return bitCount
|
||||
}
|
||||
|
||||
// Look at the leaves and assign them a bit count and an encoding as specified
|
||||
// in RFC 1951 3.2.2
|
||||
func (h *huffmanEncoder) assignEncodingAndSize(bitCount []int32, list []literalNode) {
|
||||
code := uint16(0)
|
||||
for n, bits := range bitCount {
|
||||
code <<= 1
|
||||
if n == 0 || bits == 0 {
|
||||
continue
|
||||
}
|
||||
// The literals list[len(list)-bits] .. list[len(list)-bits]
|
||||
// are encoded using "bits" bits, and get the values
|
||||
// code, code + 1, .... The code values are
|
||||
// assigned in literal order (not frequency order).
|
||||
chunk := list[len(list)-int(bits):]
|
||||
|
||||
h.lns.sort(chunk)
|
||||
for _, node := range chunk {
|
||||
h.codes[node.literal] = hcode{code: reverseBits(code, uint8(n)), len: uint16(n)}
|
||||
code++
|
||||
}
|
||||
list = list[0 : len(list)-int(bits)]
|
||||
}
|
||||
}
|
||||
|
||||
// Update this Huffman Code object to be the minimum code for the specified frequency count.
|
||||
//
|
||||
// freq An array of frequencies, in which frequency[i] gives the frequency of literal i.
|
||||
// maxBits The maximum number of bits to use for any literal.
|
||||
func (h *huffmanEncoder) generate(freq []int32, maxBits int32) {
|
||||
if h.freqcache == nil {
|
||||
// Allocate a reusable buffer with the longest possible frequency table.
|
||||
// Possible lengths are codegenCodeCount, offsetCodeCount and maxNumLit.
|
||||
// The largest of these is maxNumLit, so we allocate for that case.
|
||||
h.freqcache = make([]literalNode, maxNumLit+1)
|
||||
}
|
||||
list := h.freqcache[:len(freq)+1]
|
||||
// Number of non-zero literals
|
||||
count := 0
|
||||
// Set list to be the set of all non-zero literals and their frequencies
|
||||
for i, f := range freq {
|
||||
if f != 0 {
|
||||
list[count] = literalNode{uint16(i), f}
|
||||
count++
|
||||
} else {
|
||||
list[count] = literalNode{}
|
||||
h.codes[i].len = 0
|
||||
}
|
||||
}
|
||||
list[len(freq)] = literalNode{}
|
||||
|
||||
list = list[:count]
|
||||
if count <= 2 {
|
||||
// Handle the small cases here, because they are awkward for the general case code. With
|
||||
// two or fewer literals, everything has bit length 1.
|
||||
for i, node := range list {
|
||||
// "list" is in order of increasing literal value.
|
||||
h.codes[node.literal].set(uint16(i), 1)
|
||||
}
|
||||
return
|
||||
}
|
||||
h.lfs.sort(list)
|
||||
|
||||
// Get the number of literals for each bit count
|
||||
bitCount := h.bitCounts(list, maxBits)
|
||||
// And do the assignment
|
||||
h.assignEncodingAndSize(bitCount, list)
|
||||
}
|
||||
|
||||
type byLiteral []literalNode
|
||||
|
||||
func (s *byLiteral) sort(a []literalNode) {
|
||||
*s = byLiteral(a)
|
||||
sort.Sort(s)
|
||||
}
|
||||
|
||||
func (s byLiteral) Len() int { return len(s) }
|
||||
|
||||
func (s byLiteral) Less(i, j int) bool {
|
||||
return s[i].literal < s[j].literal
|
||||
}
|
||||
|
||||
func (s byLiteral) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
type byFreq []literalNode
|
||||
|
||||
func (s *byFreq) sort(a []literalNode) {
|
||||
*s = byFreq(a)
|
||||
sort.Sort(s)
|
||||
}
|
||||
|
||||
func (s byFreq) Len() int { return len(s) }
|
||||
|
||||
func (s byFreq) Less(i, j int) bool {
|
||||
if s[i].freq == s[j].freq {
|
||||
return s[i].literal < s[j].literal
|
||||
}
|
||||
return s[i].freq < s[j].freq
|
||||
}
|
||||
|
||||
func (s byFreq) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
@ -0,0 +1,868 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package flate implements the DEFLATE compressed data format, described in
|
||||
// RFC 1951. The gzip and zlib packages implement access to DEFLATE-based file
|
||||
// formats.
|
||||
package flate
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
	// maxCodeLen is the max length of a Huffman code.
	maxCodeLen = 16
	// maxCodeLenMask is the mask for the max length of a Huffman code.
	maxCodeLenMask = maxCodeLen - 1

	// The next three numbers come from the RFC section 3.2.7, with the
	// additional proviso in section 3.2.5 which implies that distance codes
	// 30 and 31 should never occur in compressed data.
	maxNumLit  = 286
	maxNumDist = 30
	// numCodes is the number of codes in the Huffman meta-code.
	numCodes = 19
)
|
||||
|
||||
// Initialize the fixedHuffmanDecoder only once upon first use.
|
||||
var fixedOnce sync.Once
|
||||
var fixedHuffmanDecoder huffmanDecoder
|
||||
|
||||
// A CorruptInputError reports the presence of corrupt input at a given
// offset.
type CorruptInputError int64

// Error returns a message naming the offset before which the input was found
// to be corrupt.
func (e CorruptInputError) Error() string {
	offset := strconv.FormatInt(int64(e), 10)
	return "flate: corrupt input before offset " + offset
}
|
||||
|
||||
// An InternalError reports an error in the flate code itself.
type InternalError string

// Error returns the error text prefixed with "flate: internal error: ".
func (e InternalError) Error() string {
	const prefix = "flate: internal error: "
	return prefix + string(e)
}
|
||||
|
||||
// A ReadError reports an error encountered while reading input.
//
// Deprecated: No longer returned.
type ReadError struct {
	// Offset is the byte offset where the error occurred.
	Offset int64
	// Err is the error returned by the underlying Read.
	Err error
}

// Error returns a message holding the offset and the text of the wrapped
// error.
func (e *ReadError) Error() string {
	offset := strconv.FormatInt(e.Offset, 10)
	cause := e.Err.Error()
	return "flate: read error at offset " + offset + ": " + cause
}
|
||||
|
||||
// A WriteError reports an error encountered while writing output.
//
// Deprecated: No longer returned.
type WriteError struct {
	// Offset is the byte offset where the error occurred.
	Offset int64
	// Err is the error returned by the underlying Write.
	Err error
}

// Error returns a message holding the offset and the text of the wrapped
// error.
func (e *WriteError) Error() string {
	offset := strconv.FormatInt(e.Offset, 10)
	cause := e.Err.Error()
	return "flate: write error at offset " + offset + ": " + cause
}
|
||||
|
||||
// Resetter resets a ReadCloser returned by NewReader or NewReaderDict to
// switch to a new underlying Reader. This permits reusing a ReadCloser
// instead of allocating a new one.
type Resetter interface {
	// Reset discards any buffered data and resets the Resetter as if it was
	// newly initialized with the given reader.
	Reset(r io.Reader, dict []byte) error
}
|
||||
|
||||
// The data structure for decoding Huffman tables is based on that of
// zlib. There is a lookup table of a fixed bit width (huffmanChunkBits).
// For codes smaller than the table width, there are multiple entries
// (each combination of trailing bits has the same value). For codes
// larger than the table width, the table contains a link to an overflow
// table. The width of each entry in the link table is the maximum code
// size minus the chunk width.
//
// Note that you can do a lookup in the table even without all bits
// filled. Since the extra bits are zero, and the DEFLATE Huffman codes
// have the property that shorter codes come before longer ones, the
// bit length estimate in the result is a lower bound on the actual
// number of bits.
//
// See the following:
//	http://www.gzip.org/algorithm.txt

// Layout of a chunk:
//	chunk & 15 is number of bits
//	chunk >> 4 is value, including table link

const (
	// huffmanChunkBits is the bit width of the primary lookup table.
	huffmanChunkBits = 9
	// huffmanNumChunks is the number of entries in the primary lookup table.
	huffmanNumChunks = 1 << huffmanChunkBits
	// huffmanValueShift is the shift that extracts the value from a chunk.
	huffmanValueShift = 4
	// huffmanCountMask extracts the number of bits from a chunk.
	huffmanCountMask = 1<<huffmanValueShift - 1
)
|
||||
|
||||
type huffmanDecoder struct {
|
||||
min int // the minimum code length
|
||||
chunks *[huffmanNumChunks]uint32 // chunks as described above
|
||||
links [][]uint32 // overflow links
|
||||
linkMask uint32 // mask the width of the link table
|
||||
}
|
||||
|
||||
// Initialize Huffman decoding tables from array of code lengths.
|
||||
// Following this function, h is guaranteed to be initialized into a complete
|
||||
// tree (i.e., neither over-subscribed nor under-subscribed). The exception is a
|
||||
// degenerate case where the tree has only a single symbol with length 1. Empty
|
||||
// trees are permitted.
|
||||
func (h *huffmanDecoder) init(bits []int) bool {
|
||||
// Sanity enables additional runtime tests during Huffman
|
||||
// table construction. It's intended to be used during
|
||||
// development to supplement the currently ad-hoc unit tests.
|
||||
const sanity = false
|
||||
|
||||
if h.chunks == nil {
|
||||
h.chunks = &[huffmanNumChunks]uint32{}
|
||||
}
|
||||
if h.min != 0 {
|
||||
*h = huffmanDecoder{chunks: h.chunks, links: h.links}
|
||||
}
|
||||
|
||||
// Count number of codes of each length,
|
||||
// compute min and max length.
|
||||
var count [maxCodeLen]int
|
||||
var min, max int
|
||||
for _, n := range bits {
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
if min == 0 || n < min {
|
||||
min = n
|
||||
}
|
||||
if n > max {
|
||||
max = n
|
||||
}
|
||||
count[n&maxCodeLenMask]++
|
||||
}
|
||||
|
||||
// Empty tree. The decompressor.huffSym function will fail later if the tree
|
||||
// is used. Technically, an empty tree is only valid for the HDIST tree and
|
||||
// not the HCLEN and HLIT tree. However, a stream with an empty HCLEN tree
|
||||
// is guaranteed to fail since it will attempt to use the tree to decode the
|
||||
// codes for the HLIT and HDIST trees. Similarly, an empty HLIT tree is
|
||||
// guaranteed to fail later since the compressed data section must be
|
||||
// composed of at least one symbol (the end-of-block marker).
|
||||
if max == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
code := 0
|
||||
var nextcode [maxCodeLen]int
|
||||
for i := min; i <= max; i++ {
|
||||
code <<= 1
|
||||
nextcode[i&maxCodeLenMask] = code
|
||||
code += count[i&maxCodeLenMask]
|
||||
}
|
||||
|
||||
// Check that the coding is complete (i.e., that we've
|
||||
// assigned all 2-to-the-max possible bit sequences).
|
||||
// Exception: To be compatible with zlib, we also need to
|
||||
// accept degenerate single-code codings. See also
|
||||
// TestDegenerateHuffmanCoding.
|
||||
if code != 1<<uint(max) && !(code == 1 && max == 1) {
|
||||
return false
|
||||
}
|
||||
|
||||
h.min = min
|
||||
chunks := h.chunks[:]
|
||||
for i := range chunks {
|
||||
chunks[i] = 0
|
||||
}
|
||||
|
||||
if max > huffmanChunkBits {
|
||||
numLinks := 1 << (uint(max) - huffmanChunkBits)
|
||||
h.linkMask = uint32(numLinks - 1)
|
||||
|
||||
// create link tables
|
||||
link := nextcode[huffmanChunkBits+1] >> 1
|
||||
if cap(h.links) < huffmanNumChunks-link {
|
||||
h.links = make([][]uint32, huffmanNumChunks-link)
|
||||
} else {
|
||||
h.links = h.links[:huffmanNumChunks-link]
|
||||
}
|
||||
for j := uint(link); j < huffmanNumChunks; j++ {
|
||||
reverse := int(reverseByte[j>>8]) | int(reverseByte[j&0xff])<<8
|
||||
reverse >>= uint(16 - huffmanChunkBits)
|
||||
off := j - uint(link)
|
||||
if sanity && h.chunks[reverse] != 0 {
|
||||
panic("impossible: overwriting existing chunk")
|
||||
}
|
||||
h.chunks[reverse] = uint32(off<<huffmanValueShift | (huffmanChunkBits + 1))
|
||||
if cap(h.links[off]) < numLinks {
|
||||
h.links[off] = make([]uint32, numLinks)
|
||||
} else {
|
||||
links := h.links[off][:0]
|
||||
h.links[off] = links[:numLinks]
|
||||
}
|
||||
}
|
||||
} else {
|
||||
h.links = h.links[:0]
|
||||
}
|
||||
|
||||
for i, n := range bits {
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
code := nextcode[n]
|
||||
nextcode[n]++
|
||||
chunk := uint32(i<<huffmanValueShift | n)
|
||||
reverse := int(reverseByte[code>>8]) | int(reverseByte[code&0xff])<<8
|
||||
reverse >>= uint(16 - n)
|
||||
if n <= huffmanChunkBits {
|
||||
for off := reverse; off < len(h.chunks); off += 1 << uint(n) {
|
||||
// We should never need to overwrite
|
||||
// an existing chunk. Also, 0 is
|
||||
// never a valid chunk, because the
|
||||
// lower 4 "count" bits should be
|
||||
// between 1 and 15.
|
||||
if sanity && h.chunks[off] != 0 {
|
||||
panic("impossible: overwriting existing chunk")
|
||||
}
|
||||
h.chunks[off] = chunk
|
||||
}
|
||||
} else {
|
||||
j := reverse & (huffmanNumChunks - 1)
|
||||
if sanity && h.chunks[j]&huffmanCountMask != huffmanChunkBits+1 {
|
||||
// Longer codes should have been
|
||||
// associated with a link table above.
|
||||
panic("impossible: not an indirect chunk")
|
||||
}
|
||||
value := h.chunks[j] >> huffmanValueShift
|
||||
linktab := h.links[value]
|
||||
reverse >>= huffmanChunkBits
|
||||
for off := reverse; off < len(linktab); off += 1 << uint(n-huffmanChunkBits) {
|
||||
if sanity && linktab[off] != 0 {
|
||||
panic("impossible: overwriting existing chunk")
|
||||
}
|
||||
linktab[off] = chunk
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sanity {
|
||||
// Above we've sanity checked that we never overwrote
|
||||
// an existing entry. Here we additionally check that
|
||||
// we filled the tables completely.
|
||||
for i, chunk := range h.chunks {
|
||||
if chunk == 0 {
|
||||
// As an exception, in the degenerate
|
||||
// single-code case, we allow odd
|
||||
// chunks to be missing.
|
||||
if code == 1 && i%2 == 1 {
|
||||
continue
|
||||
}
|
||||
panic("impossible: missing chunk")
|
||||
}
|
||||
}
|
||||
for _, linktab := range h.links {
|
||||
for _, chunk := range linktab {
|
||||
if chunk == 0 {
|
||||
panic("impossible: missing chunk")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Reader is the read interface the decompressor actually needs: byte-wise
// reads in addition to bulk reads. When the io.Reader handed to NewReader
// does not provide ReadByte as well, NewReader adds its own buffering.
type Reader interface {
	io.Reader
	io.ByteReader
}
|
||||
|
||||
// Decompress state.
|
||||
type decompressor struct {
|
||||
// Input source.
|
||||
r Reader
|
||||
roffset int64
|
||||
|
||||
// Input bits, in top of b.
|
||||
b uint32
|
||||
nb uint
|
||||
|
||||
// Huffman decoders for literal/length, distance.
|
||||
h1, h2 huffmanDecoder
|
||||
|
||||
// Length arrays used to define Huffman codes.
|
||||
bits *[maxNumLit + maxNumDist]int
|
||||
codebits *[numCodes]int
|
||||
|
||||
// Output history, buffer.
|
||||
dict dictDecoder
|
||||
|
||||
// Temporary buffer (avoids repeated allocation).
|
||||
buf [4]byte
|
||||
|
||||
// Next step in the decompression,
|
||||
// and decompression state.
|
||||
step func(*decompressor)
|
||||
stepState int
|
||||
final bool
|
||||
err error
|
||||
toRead []byte
|
||||
hl, hd *huffmanDecoder
|
||||
copyLen int
|
||||
copyDist int
|
||||
}
|
||||
|
||||
func (f *decompressor) nextBlock() {
|
||||
for f.nb < 1+2 {
|
||||
if f.err = f.moreBits(); f.err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
f.final = f.b&1 == 1
|
||||
f.b >>= 1
|
||||
typ := f.b & 3
|
||||
f.b >>= 2
|
||||
f.nb -= 1 + 2
|
||||
switch typ {
|
||||
case 0:
|
||||
f.dataBlock()
|
||||
case 1:
|
||||
// compressed, fixed Huffman tables
|
||||
f.hl = &fixedHuffmanDecoder
|
||||
f.hd = nil
|
||||
f.huffmanBlock()
|
||||
case 2:
|
||||
// compressed, dynamic Huffman tables
|
||||
if f.err = f.readHuffman(); f.err != nil {
|
||||
break
|
||||
}
|
||||
f.hl = &f.h1
|
||||
f.hd = &f.h2
|
||||
f.huffmanBlock()
|
||||
default:
|
||||
// 3 is reserved.
|
||||
f.err = CorruptInputError(f.roffset)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *decompressor) Read(b []byte) (int, error) {
|
||||
for {
|
||||
if len(f.toRead) > 0 {
|
||||
n := copy(b, f.toRead)
|
||||
f.toRead = f.toRead[n:]
|
||||
if len(f.toRead) == 0 {
|
||||
return n, f.err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
if f.err != nil {
|
||||
return 0, f.err
|
||||
}
|
||||
f.step(f)
|
||||
if f.err != nil && len(f.toRead) == 0 {
|
||||
f.toRead = f.dict.readFlush() // Flush what's left in case of error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Support the io.WriteTo interface for io.Copy and friends.
|
||||
func (f *decompressor) WriteTo(w io.Writer) (int64, error) {
|
||||
total := int64(0)
|
||||
flushed := false
|
||||
for {
|
||||
if len(f.toRead) > 0 {
|
||||
n, err := w.Write(f.toRead)
|
||||
total += int64(n)
|
||||
if err != nil {
|
||||
f.err = err
|
||||
return total, err
|
||||
}
|
||||
if n != len(f.toRead) {
|
||||
return total, io.ErrShortWrite
|
||||
}
|
||||
f.toRead = f.toRead[:0]
|
||||
}
|
||||
if f.err != nil && flushed {
|
||||
if f.err == io.EOF {
|
||||
return total, nil
|
||||
}
|
||||
return total, f.err
|
||||
}
|
||||
if f.err == nil {
|
||||
f.step(f)
|
||||
}
|
||||
if len(f.toRead) == 0 && f.err != nil && !flushed {
|
||||
f.toRead = f.dict.readFlush() // Flush what's left in case of error
|
||||
flushed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *decompressor) Close() error {
|
||||
if f.err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return f.err
|
||||
}
|
||||
|
||||
// RFC 1951 section 3.2.7.
// Compression with dynamic Huffman codes

// codeOrder is the order in which the code lengths of the code length
// alphabet are read by readHuffman.
var codeOrder = [...]int{
	16, 17, 18, 0,
	8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
}
|
||||
|
||||
func (f *decompressor) readHuffman() error {
|
||||
// HLIT[5], HDIST[5], HCLEN[4].
|
||||
for f.nb < 5+5+4 {
|
||||
if err := f.moreBits(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
nlit := int(f.b&0x1F) + 257
|
||||
if nlit > maxNumLit {
|
||||
return CorruptInputError(f.roffset)
|
||||
}
|
||||
f.b >>= 5
|
||||
ndist := int(f.b&0x1F) + 1
|
||||
if ndist > maxNumDist {
|
||||
return CorruptInputError(f.roffset)
|
||||
}
|
||||
f.b >>= 5
|
||||
nclen := int(f.b&0xF) + 4
|
||||
// numCodes is 19, so nclen is always valid.
|
||||
f.b >>= 4
|
||||
f.nb -= 5 + 5 + 4
|
||||
|
||||
// (HCLEN+4)*3 bits: code lengths in the magic codeOrder order.
|
||||
for i := 0; i < nclen; i++ {
|
||||
for f.nb < 3 {
|
||||
if err := f.moreBits(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
f.codebits[codeOrder[i]] = int(f.b & 0x7)
|
||||
f.b >>= 3
|
||||
f.nb -= 3
|
||||
}
|
||||
for i := nclen; i < len(codeOrder); i++ {
|
||||
f.codebits[codeOrder[i]] = 0
|
||||
}
|
||||
if !f.h1.init(f.codebits[0:]) {
|
||||
return CorruptInputError(f.roffset)
|
||||
}
|
||||
|
||||
// HLIT + 257 code lengths, HDIST + 1 code lengths,
|
||||
// using the code length Huffman code.
|
||||
for i, n := 0, nlit+ndist; i < n; {
|
||||
x, err := f.huffSym(&f.h1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if x < 16 {
|
||||
// Actual length.
|
||||
f.bits[i] = x
|
||||
i++
|
||||
continue
|
||||
}
|
||||
// Repeat previous length or zero.
|
||||
var rep int
|
||||
var nb uint
|
||||
var b int
|
||||
switch x {
|
||||
default:
|
||||
return InternalError("unexpected length code")
|
||||
case 16:
|
||||
rep = 3
|
||||
nb = 2
|
||||
if i == 0 {
|
||||
return CorruptInputError(f.roffset)
|
||||
}
|
||||
b = f.bits[i-1]
|
||||
case 17:
|
||||
rep = 3
|
||||
nb = 3
|
||||
b = 0
|
||||
case 18:
|
||||
rep = 11
|
||||
nb = 7
|
||||
b = 0
|
||||
}
|
||||
for f.nb < nb {
|
||||
if err := f.moreBits(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
rep += int(f.b & uint32(1<<nb-1))
|
||||
f.b >>= nb
|
||||
f.nb -= nb
|
||||
if i+rep > n {
|
||||
return CorruptInputError(f.roffset)
|
||||
}
|
||||
for j := 0; j < rep; j++ {
|
||||
f.bits[i] = b
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
if !f.h1.init(f.bits[0:nlit]) || !f.h2.init(f.bits[nlit:nlit+ndist]) {
|
||||
return CorruptInputError(f.roffset)
|
||||
}
|
||||
|
||||
// As an optimization, we can initialize the min bits to read at a time
|
||||
// for the HLIT tree to the length of the EOB marker since we know that
|
||||
// every block must terminate with one. This preserves the property that
|
||||
// we never read any extra bytes after the end of the DEFLATE stream.
|
||||
if f.h1.min < f.bits[endBlockMarker] {
|
||||
f.h1.min = f.bits[endBlockMarker]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a single Huffman block from f.
|
||||
// hl and hd are the Huffman states for the lit/length values
|
||||
// and the distance values, respectively. If hd == nil, using the
|
||||
// fixed distance encoding associated with fixed Huffman blocks.
|
||||
func (f *decompressor) huffmanBlock() {
|
||||
const (
|
||||
stateInit = iota // Zero value must be stateInit
|
||||
stateDict
|
||||
)
|
||||
|
||||
switch f.stepState {
|
||||
case stateInit:
|
||||
goto readLiteral
|
||||
case stateDict:
|
||||
goto copyHistory
|
||||
}
|
||||
|
||||
readLiteral:
|
||||
// Read literal and/or (length, distance) according to RFC section 3.2.3.
|
||||
{
|
||||
v, err := f.huffSym(f.hl)
|
||||
if err != nil {
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
var n uint // number of bits extra
|
||||
var length int
|
||||
switch {
|
||||
case v < 256:
|
||||
f.dict.writeByte(byte(v))
|
||||
if f.dict.availWrite() == 0 {
|
||||
f.toRead = f.dict.readFlush()
|
||||
f.step = (*decompressor).huffmanBlock
|
||||
f.stepState = stateInit
|
||||
return
|
||||
}
|
||||
goto readLiteral
|
||||
case v == 256:
|
||||
f.finishBlock()
|
||||
return
|
||||
// otherwise, reference to older data
|
||||
case v < 265:
|
||||
length = v - (257 - 3)
|
||||
n = 0
|
||||
case v < 269:
|
||||
length = v*2 - (265*2 - 11)
|
||||
n = 1
|
||||
case v < 273:
|
||||
length = v*4 - (269*4 - 19)
|
||||
n = 2
|
||||
case v < 277:
|
||||
length = v*8 - (273*8 - 35)
|
||||
n = 3
|
||||
case v < 281:
|
||||
length = v*16 - (277*16 - 67)
|
||||
n = 4
|
||||
case v < 285:
|
||||
length = v*32 - (281*32 - 131)
|
||||
n = 5
|
||||
case v < maxNumLit:
|
||||
length = 258
|
||||
n = 0
|
||||
default:
|
||||
f.err = CorruptInputError(f.roffset)
|
||||
return
|
||||
}
|
||||
if n > 0 {
|
||||
for f.nb < n {
|
||||
if err = f.moreBits(); err != nil {
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
}
|
||||
length += int(f.b & uint32(1<<n-1))
|
||||
f.b >>= n
|
||||
f.nb -= n
|
||||
}
|
||||
|
||||
var dist int
|
||||
if f.hd == nil {
|
||||
for f.nb < 5 {
|
||||
if err = f.moreBits(); err != nil {
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
}
|
||||
dist = int(reverseByte[(f.b&0x1F)<<3])
|
||||
f.b >>= 5
|
||||
f.nb -= 5
|
||||
} else {
|
||||
if dist, err = f.huffSym(f.hd); err != nil {
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case dist < 4:
|
||||
dist++
|
||||
case dist < maxNumDist:
|
||||
nb := uint(dist-2) >> 1
|
||||
// have 1 bit in bottom of dist, need nb more.
|
||||
extra := (dist & 1) << nb
|
||||
for f.nb < nb {
|
||||
if err = f.moreBits(); err != nil {
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
}
|
||||
extra |= int(f.b & uint32(1<<nb-1))
|
||||
f.b >>= nb
|
||||
f.nb -= nb
|
||||
dist = 1<<(nb+1) + 1 + extra
|
||||
default:
|
||||
f.err = CorruptInputError(f.roffset)
|
||||
return
|
||||
}
|
||||
|
||||
// No check on length; encoding can be prescient.
|
||||
if dist > f.dict.histSize() {
|
||||
f.err = CorruptInputError(f.roffset)
|
||||
return
|
||||
}
|
||||
|
||||
f.copyLen, f.copyDist = length, dist
|
||||
goto copyHistory
|
||||
}
|
||||
|
||||
copyHistory:
|
||||
// Perform a backwards copy according to RFC section 3.2.3.
|
||||
{
|
||||
cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen)
|
||||
if cnt == 0 {
|
||||
cnt = f.dict.writeCopy(f.copyDist, f.copyLen)
|
||||
}
|
||||
f.copyLen -= cnt
|
||||
|
||||
if f.dict.availWrite() == 0 || f.copyLen > 0 {
|
||||
f.toRead = f.dict.readFlush()
|
||||
f.step = (*decompressor).huffmanBlock // We need to continue this work
|
||||
f.stepState = stateDict
|
||||
return
|
||||
}
|
||||
goto readLiteral
|
||||
}
|
||||
}
|
||||
|
||||
// Copy a single uncompressed data block from input to output.
|
||||
func (f *decompressor) dataBlock() {
|
||||
// Uncompressed.
|
||||
// Discard current half-byte.
|
||||
f.nb = 0
|
||||
f.b = 0
|
||||
|
||||
// Length then ones-complement of length.
|
||||
nr, err := io.ReadFull(f.r, f.buf[0:4])
|
||||
f.roffset += int64(nr)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
n := int(f.buf[0]) | int(f.buf[1])<<8
|
||||
nn := int(f.buf[2]) | int(f.buf[3])<<8
|
||||
if uint16(nn) != uint16(^n) {
|
||||
f.err = CorruptInputError(f.roffset)
|
||||
return
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
f.toRead = f.dict.readFlush()
|
||||
f.finishBlock()
|
||||
return
|
||||
}
|
||||
|
||||
f.copyLen = n
|
||||
f.copyData()
|
||||
}
|
||||
|
||||
// copyData copies f.copyLen bytes from the underlying reader into f.hist.
|
||||
// It pauses for reads when f.hist is full.
|
||||
func (f *decompressor) copyData() {
|
||||
buf := f.dict.writeSlice()
|
||||
if len(buf) > f.copyLen {
|
||||
buf = buf[:f.copyLen]
|
||||
}
|
||||
|
||||
cnt, err := io.ReadFull(f.r, buf)
|
||||
f.roffset += int64(cnt)
|
||||
f.copyLen -= cnt
|
||||
f.dict.writeMark(cnt)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
f.err = err
|
||||
return
|
||||
}
|
||||
|
||||
if f.dict.availWrite() == 0 || f.copyLen > 0 {
|
||||
f.toRead = f.dict.readFlush()
|
||||
f.step = (*decompressor).copyData
|
||||
return
|
||||
}
|
||||
f.finishBlock()
|
||||
}
|
||||
|
||||
func (f *decompressor) finishBlock() {
|
||||
if f.final {
|
||||
if f.dict.availRead() > 0 {
|
||||
f.toRead = f.dict.readFlush()
|
||||
}
|
||||
f.err = io.EOF
|
||||
}
|
||||
f.step = (*decompressor).nextBlock
|
||||
}
|
||||
|
||||
func (f *decompressor) moreBits() error {
|
||||
c, err := f.r.ReadByte()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
f.roffset++
|
||||
f.b |= uint32(c) << f.nb
|
||||
f.nb += 8
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read the next Huffman-encoded symbol from f according to h.
|
||||
func (f *decompressor) huffSym(h *huffmanDecoder) (int, error) {
|
||||
// Since a huffmanDecoder can be empty or be composed of a degenerate tree
|
||||
// with single element, huffSym must error on these two edge cases. In both
|
||||
// cases, the chunks slice will be 0 for the invalid sequence, leading it
|
||||
// satisfy the n == 0 check below.
|
||||
n := uint(h.min)
|
||||
for {
|
||||
for f.nb < n {
|
||||
if err := f.moreBits(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
chunk := h.chunks[f.b&(huffmanNumChunks-1)]
|
||||
n = uint(chunk & huffmanCountMask)
|
||||
if n > huffmanChunkBits {
|
||||
chunk = h.links[chunk>>huffmanValueShift][(f.b>>huffmanChunkBits)&h.linkMask]
|
||||
n = uint(chunk & huffmanCountMask)
|
||||
}
|
||||
if n <= f.nb {
|
||||
if n == 0 {
|
||||
f.err = CorruptInputError(f.roffset)
|
||||
return 0, f.err
|
||||
}
|
||||
f.b >>= n
|
||||
f.nb -= n
|
||||
return int(chunk >> huffmanValueShift), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeReader(r io.Reader) Reader {
|
||||
if rr, ok := r.(Reader); ok {
|
||||
return rr
|
||||
}
|
||||
return bufio.NewReader(r)
|
||||
}
|
||||
|
||||
func fixedHuffmanDecoderInit() {
|
||||
fixedOnce.Do(func() {
|
||||
// These come from the RFC section 3.2.6.
|
||||
var bits [288]int
|
||||
for i := 0; i < 144; i++ {
|
||||
bits[i] = 8
|
||||
}
|
||||
for i := 144; i < 256; i++ {
|
||||
bits[i] = 9
|
||||
}
|
||||
for i := 256; i < 280; i++ {
|
||||
bits[i] = 7
|
||||
}
|
||||
for i := 280; i < 288; i++ {
|
||||
bits[i] = 8
|
||||
}
|
||||
fixedHuffmanDecoder.init(bits[:])
|
||||
})
|
||||
}
|
||||
|
||||
func (f *decompressor) Reset(r io.Reader, dict []byte) error {
|
||||
*f = decompressor{
|
||||
r: makeReader(r),
|
||||
bits: f.bits,
|
||||
codebits: f.codebits,
|
||||
h1: f.h1,
|
||||
h2: f.h2,
|
||||
dict: f.dict,
|
||||
step: (*decompressor).nextBlock,
|
||||
}
|
||||
f.dict.init(maxMatchOffset, dict)
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewReader returns a new ReadCloser that can be used
|
||||
// to read the uncompressed version of r.
|
||||
// If r does not also implement io.ByteReader,
|
||||
// the decompressor may read more data than necessary from r.
|
||||
// It is the caller's responsibility to call Close on the ReadCloser
|
||||
// when finished reading.
|
||||
//
|
||||
// The ReadCloser returned by NewReader also implements Resetter.
|
||||
func NewReader(r io.Reader) io.ReadCloser {
|
||||
fixedHuffmanDecoderInit()
|
||||
|
||||
var f decompressor
|
||||
f.r = makeReader(r)
|
||||
f.bits = new([maxNumLit + maxNumDist]int)
|
||||
f.codebits = new([numCodes]int)
|
||||
f.step = (*decompressor).nextBlock
|
||||
f.dict.init(maxMatchOffset, nil)
|
||||
return &f
|
||||
}
|
||||
|
||||
// NewReaderDict is like NewReader but initializes the reader
|
||||
// with a preset dictionary. The returned Reader behaves as if
|
||||
// the uncompressed data stream started with the given dictionary,
|
||||
// which has already been read. NewReaderDict is typically used
|
||||
// to read data compressed by NewWriterDict.
|
||||
//
|
||||
// The ReadCloser returned by NewReader also implements Resetter.
|
||||
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
|
||||
fixedHuffmanDecoderInit()
|
||||
|
||||
var f decompressor
|
||||
f.r = makeReader(r)
|
||||
f.bits = new([maxNumLit + maxNumDist]int)
|
||||
f.codebits = new([numCodes]int)
|
||||
f.step = (*decompressor).nextBlock
|
||||
f.dict.init(maxMatchOffset, dict)
|
||||
return &f
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package flate
|
||||
|
||||
// reverseByte maps every byte value to the value with its eight bits in
// reverse order.
var reverseByte = [256]byte{
	0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0, 0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
	0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8, 0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
	0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4, 0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
	0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec, 0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
	0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2, 0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
	0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea, 0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
	0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6, 0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
	0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee, 0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
	0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1, 0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
	0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9, 0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
	0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5, 0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
	0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed, 0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
	0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3, 0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
	0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb, 0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
	0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7, 0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
	0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef, 0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
}
|
||||
|
||||
func reverseUint16(v uint16) uint16 {
|
||||
return uint16(reverseByte[v>>8]) | uint16(reverseByte[v&0xFF])<<8
|
||||
}
|
||||
|
||||
func reverseBits(number uint16, bitLength byte) uint16 {
|
||||
return reverseUint16(number << uint8(16-bitLength))
|
||||
}
|
|
@ -0,0 +1,900 @@
|
|||
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
|
||||
// Modified for deflate by Klaus Post (c) 2015.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package flate
|
||||
|
||||
// emitLiteral writes a literal chunk and returns the number of bytes written.
|
||||
func emitLiteral(dst *tokens, lit []byte) {
|
||||
ol := int(dst.n)
|
||||
for i, v := range lit {
|
||||
dst.tokens[(i+ol)&maxStoreBlockSize] = token(v)
|
||||
}
|
||||
dst.n += uint16(len(lit))
|
||||
}
|
||||
|
||||
// emitCopy writes a copy chunk and returns the number of bytes written.
|
||||
func emitCopy(dst *tokens, offset, length int) {
|
||||
dst.tokens[dst.n] = matchToken(uint32(length-3), uint32(offset-minOffsetSize))
|
||||
dst.n++
|
||||
}
|
||||
|
||||
type snappyEnc interface {
|
||||
Encode(dst *tokens, src []byte)
|
||||
Reset()
|
||||
}
|
||||
|
||||
func newSnappy(level int) snappyEnc {
|
||||
switch level {
|
||||
case 1:
|
||||
return &snappyL1{}
|
||||
case 2:
|
||||
return &snappyL2{snappyGen: snappyGen{cur: maxStoreBlockSize, prev: make([]byte, 0, maxStoreBlockSize)}}
|
||||
case 3:
|
||||
return &snappyL3{snappyGen: snappyGen{cur: maxStoreBlockSize, prev: make([]byte, 0, maxStoreBlockSize)}}
|
||||
case 4:
|
||||
return &snappyL4{snappyL3{snappyGen: snappyGen{cur: maxStoreBlockSize, prev: make([]byte, 0, maxStoreBlockSize)}}}
|
||||
default:
|
||||
panic("invalid level specified")
|
||||
}
|
||||
}
|
||||
|
||||
const (
	// tableBits is the number of bits used in the table.
	tableBits = 14
	// tableSize is the size of the table.
	tableSize = 1 << tableBits
	// tableMask is the mask for table indices. Redundant, but can
	// eliminate bounds checks.
	tableMask = tableSize - 1
	// tableShift is the right-shift to get the tableBits most significant
	// bits of a uint32.
	tableShift = 32 - tableBits
	// baseMatchOffset is the smallest match offset.
	baseMatchOffset = 1
	// baseMatchLength is the smallest match length per the RFC
	// section 3.2.5.
	baseMatchLength = 3
	// maxMatchOffset is the largest match offset.
	maxMatchOffset = 1 << 15
)
|
||||
|
||||
// load32 returns the little-endian uint32 stored in b at offset i.
func load32(b []byte, i int) uint32 {
	// Re-slicing to exactly four bytes helps the compiler eliminate the
	// bounds checks on the loads below.
	w := b[i : i+4 : len(b)]
	v := uint32(w[0])
	v |= uint32(w[1]) << 8
	v |= uint32(w[2]) << 16
	v |= uint32(w[3]) << 24
	return v
}
|
||||
|
||||
// load64 returns the little-endian uint64 stored in b at offset i.
func load64(b []byte, i int) uint64 {
	// Re-slicing to exactly eight bytes helps the compiler eliminate the
	// bounds checks on the loads below.
	w := b[i : i+8 : len(b)]
	lo := uint64(w[0]) | uint64(w[1])<<8 | uint64(w[2])<<16 | uint64(w[3])<<24
	hi := uint64(w[4]) | uint64(w[5])<<8 | uint64(w[6])<<16 | uint64(w[7])<<24
	return lo | hi<<32
}
|
||||
|
||||
func hash(u uint32) uint32 {
|
||||
return (u * 0x1e35a7bd) >> tableShift
|
||||
}
|
||||
|
||||
// snappyL1 encapsulates level 1 compression. It holds no state.
type snappyL1 struct{}
|
||||
|
||||
// Reset does nothing; it exists to satisfy snappyEnc.
func (e *snappyL1) Reset() {
}
|
||||
|
||||
func (e *snappyL1) Encode(dst *tokens, src []byte) {
|
||||
const (
|
||||
inputMargin = 16 - 1
|
||||
minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||
)
|
||||
|
||||
// This check isn't in the Snappy implementation, but there, the caller
|
||||
// instead of the callee handles this case.
|
||||
if len(src) < minNonLiteralBlockSize {
|
||||
// We do not fill the token table.
|
||||
// This will be picked up by caller.
|
||||
dst.n = uint16(len(src))
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize the hash table.
|
||||
//
|
||||
// The table element type is uint16, as s < sLimit and sLimit < len(src)
|
||||
// and len(src) <= maxStoreBlockSize and maxStoreBlockSize == 65535.
|
||||
var table [tableSize]uint16
|
||||
|
||||
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||
// looking for copies.
|
||||
sLimit := len(src) - inputMargin
|
||||
|
||||
// nextEmit is where in src the next emitLiteral should start from.
|
||||
nextEmit := 0
|
||||
|
||||
// The encoded form must start with a literal, as there are no previous
|
||||
// bytes to copy, so we start looking for hash matches at s == 1.
|
||||
s := 1
|
||||
nextHash := hash(load32(src, s))
|
||||
|
||||
for {
|
||||
// Copied from the C++ snappy implementation:
|
||||
//
|
||||
// Heuristic match skipping: If 32 bytes are scanned with no matches
|
||||
// found, start looking only at every other byte. If 32 more bytes are
|
||||
// scanned (or skipped), look at every third byte, etc.. When a match
|
||||
// is found, immediately go back to looking at every byte. This is a
|
||||
// small loss (~5% performance, ~0.1% density) for compressible data
|
||||
// due to more bookkeeping, but for non-compressible data (such as
|
||||
// JPEG) it's a huge win since the compressor quickly "realizes" the
|
||||
// data is incompressible and doesn't bother looking for matches
|
||||
// everywhere.
|
||||
//
|
||||
// The "skip" variable keeps track of how many bytes there are since
|
||||
// the last match; dividing it by 32 (ie. right-shifting by five) gives
|
||||
// the number of bytes to move ahead for each iteration.
|
||||
skip := 32
|
||||
|
||||
nextS := s
|
||||
candidate := 0
|
||||
for {
|
||||
s = nextS
|
||||
bytesBetweenHashLookups := skip >> 5
|
||||
nextS = s + bytesBetweenHashLookups
|
||||
skip += bytesBetweenHashLookups
|
||||
if nextS > sLimit {
|
||||
goto emitRemainder
|
||||
}
|
||||
candidate = int(table[nextHash&tableMask])
|
||||
table[nextHash&tableMask] = uint16(s)
|
||||
nextHash = hash(load32(src, nextS))
|
||||
if s-candidate <= maxMatchOffset && load32(src, s) == load32(src, candidate) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// A 4-byte match has been found. We'll later see if more than 4 bytes
|
||||
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
|
||||
// them as literal bytes.
|
||||
emitLiteral(dst, src[nextEmit:s])
|
||||
|
||||
// Call emitCopy, and then see if another emitCopy could be our next
|
||||
// move. Repeat until we find no match for the input immediately after
|
||||
// what was consumed by the last emitCopy call.
|
||||
//
|
||||
// If we exit this loop normally then we need to call emitLiteral next,
|
||||
// though we don't yet know how big the literal will be. We handle that
|
||||
// by proceeding to the next iteration of the main loop. We also can
|
||||
// exit this loop via goto if we get close to exhausting the input.
|
||||
for {
|
||||
// Invariant: we have a 4-byte match at s, and no need to emit any
|
||||
// literal bytes prior to s.
|
||||
base := s
|
||||
|
||||
// Extend the 4-byte match as long as possible.
|
||||
//
|
||||
// This is an inlined version of Snappy's:
|
||||
// s = extendMatch(src, candidate+4, s+4)
|
||||
s += 4
|
||||
s1 := base + maxMatchLength
|
||||
if s1 > len(src) {
|
||||
s1 = len(src)
|
||||
}
|
||||
a := src[s:s1]
|
||||
b := src[candidate+4:]
|
||||
b = b[:len(a)]
|
||||
l := len(a)
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
l = i
|
||||
break
|
||||
}
|
||||
}
|
||||
s += l
|
||||
|
||||
// matchToken is flate's equivalent of Snappy's emitCopy.
|
||||
dst.tokens[dst.n] = matchToken(uint32(s-base-baseMatchLength), uint32(base-candidate-baseMatchOffset))
|
||||
dst.n++
|
||||
nextEmit = s
|
||||
if s >= sLimit {
|
||||
goto emitRemainder
|
||||
}
|
||||
|
||||
// We could immediately start working at s now, but to improve
|
||||
// compression we first update the hash table at s-1 and at s. If
|
||||
// another emitCopy is not our next move, also calculate nextHash
|
||||
// at s+1. At least on GOARCH=amd64, these three hash calculations
|
||||
// are faster as one load64 call (with some shifts) instead of
|
||||
// three load32 calls.
|
||||
x := load64(src, s-1)
|
||||
prevHash := hash(uint32(x >> 0))
|
||||
table[prevHash&tableMask] = uint16(s - 1)
|
||||
currHash := hash(uint32(x >> 8))
|
||||
candidate = int(table[currHash&tableMask])
|
||||
table[currHash&tableMask] = uint16(s)
|
||||
if s-candidate > maxMatchOffset || uint32(x>>8) != load32(src, candidate) {
|
||||
nextHash = hash(uint32(x >> 16))
|
||||
s++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
emitRemainder:
|
||||
if nextEmit < len(src) {
|
||||
emitLiteral(dst, src[nextEmit:])
|
||||
}
|
||||
}
|
||||
|
||||
// tableEntry is one slot of an encoder hash table: a 32-bit value paired
// with an offset.
type tableEntry struct {
	val    uint32
	offset int32
}
|
||||
|
||||
// load3232 returns the little-endian uint32 stored in b at the int32
// offset i.
func load3232(b []byte, i int32) uint32 {
	// Re-slicing to exactly four bytes helps the compiler eliminate the
	// bounds checks on the loads below.
	w := b[i : i+4 : len(b)]
	v := uint32(w[0])
	v |= uint32(w[1]) << 8
	v |= uint32(w[2]) << 16
	v |= uint32(w[3]) << 24
	return v
}
|
||||
|
||||
// load6432 returns the little-endian uint64 stored in b at the int32
// offset i.
func load6432(b []byte, i int32) uint64 {
	// Re-slicing to exactly eight bytes helps the compiler eliminate the
	// bounds checks on the loads below.
	w := b[i : i+8 : len(b)]
	lo := uint64(w[0]) | uint64(w[1])<<8 | uint64(w[2])<<16 | uint64(w[3])<<24
	hi := uint64(w[4]) | uint64(w[5])<<8 | uint64(w[6])<<16 | uint64(w[7])<<24
	return lo | hi<<32
}
|
||||
|
||||
// snappyGen is the generic state embedded by the encoders above level 1:
// the previous byte block, and the cur counter that those encoders advance
// between blocks.
type snappyGen struct {
	prev []byte
	cur  int32
}
|
||||
|
||||
// snappyGen maintains the table for matches,
|
||||
// and the previous byte block for level 2.
|
||||
// This is the generic implementation.
|
||||
type snappyL2 struct {
|
||||
snappyGen
|
||||
table [tableSize]tableEntry
|
||||
}
|
||||
|
||||
// EncodeL2 uses a similar algorithm to level 1, but is capable
|
||||
// of matching across blocks giving better compression at a small slowdown.
|
||||
func (e *snappyL2) Encode(dst *tokens, src []byte) {
|
||||
const (
|
||||
inputMargin = 8 - 1
|
||||
minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||
)
|
||||
|
||||
// Protect against e.cur wraparound.
|
||||
if e.cur > 1<<30 {
|
||||
for i := range e.table[:] {
|
||||
e.table[i] = tableEntry{}
|
||||
}
|
||||
e.cur = maxStoreBlockSize
|
||||
}
|
||||
|
||||
// This check isn't in the Snappy implementation, but there, the caller
|
||||
// instead of the callee handles this case.
|
||||
if len(src) < minNonLiteralBlockSize {
|
||||
// We do not fill the token table.
|
||||
// This will be picked up by caller.
|
||||
dst.n = uint16(len(src))
|
||||
e.cur += maxStoreBlockSize
|
||||
e.prev = e.prev[:0]
|
||||
return
|
||||
}
|
||||
|
||||
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||
// looking for copies.
|
||||
sLimit := int32(len(src) - inputMargin)
|
||||
|
||||
// nextEmit is where in src the next emitLiteral should start from.
|
||||
nextEmit := int32(0)
|
||||
s := int32(0)
|
||||
cv := load3232(src, s)
|
||||
nextHash := hash(cv)
|
||||
|
||||
for {
|
||||
// Copied from the C++ snappy implementation:
|
||||
//
|
||||
// Heuristic match skipping: If 32 bytes are scanned with no matches
|
||||
// found, start looking only at every other byte. If 32 more bytes are
|
||||
// scanned (or skipped), look at every third byte, etc.. When a match
|
||||
// is found, immediately go back to looking at every byte. This is a
|
||||
// small loss (~5% performance, ~0.1% density) for compressible data
|
||||
// due to more bookkeeping, but for non-compressible data (such as
|
||||
// JPEG) it's a huge win since the compressor quickly "realizes" the
|
||||
// data is incompressible and doesn't bother looking for matches
|
||||
// everywhere.
|
||||
//
|
||||
// The "skip" variable keeps track of how many bytes there are since
|
||||
// the last match; dividing it by 32 (ie. right-shifting by five) gives
|
||||
// the number of bytes to move ahead for each iteration.
|
||||
skip := int32(32)
|
||||
|
||||
nextS := s
|
||||
var candidate tableEntry
|
||||
for {
|
||||
s = nextS
|
||||
bytesBetweenHashLookups := skip >> 5
|
||||
nextS = s + bytesBetweenHashLookups
|
||||
skip += bytesBetweenHashLookups
|
||||
if nextS > sLimit {
|
||||
goto emitRemainder
|
||||
}
|
||||
candidate = e.table[nextHash&tableMask]
|
||||
now := load3232(src, nextS)
|
||||
e.table[nextHash&tableMask] = tableEntry{offset: s + e.cur, val: cv}
|
||||
nextHash = hash(now)
|
||||
|
||||
offset := s - (candidate.offset - e.cur)
|
||||
if offset > maxMatchOffset || cv != candidate.val {
|
||||
// Out of range or not matched.
|
||||
cv = now
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// A 4-byte match has been found. We'll later see if more than 4 bytes
|
||||
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
|
||||
// them as literal bytes.
|
||||
emitLiteral(dst, src[nextEmit:s])
|
||||
|
||||
// Call emitCopy, and then see if another emitCopy could be our next
|
||||
// move. Repeat until we find no match for the input immediately after
|
||||
// what was consumed by the last emitCopy call.
|
||||
//
|
||||
// If we exit this loop normally then we need to call emitLiteral next,
|
||||
// though we don't yet know how big the literal will be. We handle that
|
||||
// by proceeding to the next iteration of the main loop. We also can
|
||||
// exit this loop via goto if we get close to exhausting the input.
|
||||
for {
|
||||
// Invariant: we have a 4-byte match at s, and no need to emit any
|
||||
// literal bytes prior to s.
|
||||
|
||||
// Extend the 4-byte match as long as possible.
|
||||
//
|
||||
s += 4
|
||||
t := candidate.offset - e.cur + 4
|
||||
l := e.matchlen(s, t, src)
|
||||
|
||||
// matchToken is flate's equivalent of Snappy's emitCopy. (length,offset)
|
||||
dst.tokens[dst.n] = matchToken(uint32(l+4-baseMatchLength), uint32(s-t-baseMatchOffset))
|
||||
dst.n++
|
||||
s += l
|
||||
nextEmit = s
|
||||
if s >= sLimit {
|
||||
t += l
|
||||
// Index first pair after match end.
|
||||
if int(t+4) < len(src) && t > 0 {
|
||||
cv := load3232(src, t)
|
||||
e.table[hash(cv)&tableMask] = tableEntry{offset: t + e.cur, val: cv}
|
||||
}
|
||||
goto emitRemainder
|
||||
}
|
||||
|
||||
// We could immediately start working at s now, but to improve
|
||||
// compression we first update the hash table at s-1 and at s. If
|
||||
// another emitCopy is not our next move, also calculate nextHash
|
||||
// at s+1. At least on GOARCH=amd64, these three hash calculations
|
||||
// are faster as one load64 call (with some shifts) instead of
|
||||
// three load32 calls.
|
||||
x := load6432(src, s-1)
|
||||
prevHash := hash(uint32(x))
|
||||
e.table[prevHash&tableMask] = tableEntry{offset: e.cur + s - 1, val: uint32(x)}
|
||||
x >>= 8
|
||||
currHash := hash(uint32(x))
|
||||
candidate = e.table[currHash&tableMask]
|
||||
e.table[currHash&tableMask] = tableEntry{offset: e.cur + s, val: uint32(x)}
|
||||
|
||||
offset := s - (candidate.offset - e.cur)
|
||||
if offset > maxMatchOffset || uint32(x) != candidate.val {
|
||||
cv = uint32(x >> 8)
|
||||
nextHash = hash(cv)
|
||||
s++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
emitRemainder:
|
||||
if int(nextEmit) < len(src) {
|
||||
emitLiteral(dst, src[nextEmit:])
|
||||
}
|
||||
e.cur += int32(len(src))
|
||||
e.prev = e.prev[:len(src)]
|
||||
copy(e.prev, src)
|
||||
}
|
||||
|
||||
type tableEntryPrev struct {
|
||||
Cur tableEntry
|
||||
Prev tableEntry
|
||||
}
|
||||
|
||||
// snappyL3
|
||||
type snappyL3 struct {
|
||||
snappyGen
|
||||
table [tableSize]tableEntryPrev
|
||||
}
|
||||
|
||||
// Encode uses a similar algorithm to level 2, will check up to two candidates.
|
||||
func (e *snappyL3) Encode(dst *tokens, src []byte) {
|
||||
const (
|
||||
inputMargin = 8 - 1
|
||||
minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||
)
|
||||
|
||||
// Protect against e.cur wraparound.
|
||||
if e.cur > 1<<30 {
|
||||
for i := range e.table[:] {
|
||||
e.table[i] = tableEntryPrev{}
|
||||
}
|
||||
e.snappyGen = snappyGen{cur: maxStoreBlockSize, prev: e.prev[:0]}
|
||||
}
|
||||
|
||||
// This check isn't in the Snappy implementation, but there, the caller
|
||||
// instead of the callee handles this case.
|
||||
if len(src) < minNonLiteralBlockSize {
|
||||
// We do not fill the token table.
|
||||
// This will be picked up by caller.
|
||||
dst.n = uint16(len(src))
|
||||
e.cur += maxStoreBlockSize
|
||||
e.prev = e.prev[:0]
|
||||
return
|
||||
}
|
||||
|
||||
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||
// looking for copies.
|
||||
sLimit := int32(len(src) - inputMargin)
|
||||
|
||||
// nextEmit is where in src the next emitLiteral should start from.
|
||||
nextEmit := int32(0)
|
||||
s := int32(0)
|
||||
cv := load3232(src, s)
|
||||
nextHash := hash(cv)
|
||||
|
||||
for {
|
||||
// Copied from the C++ snappy implementation:
|
||||
//
|
||||
// Heuristic match skipping: If 32 bytes are scanned with no matches
|
||||
// found, start looking only at every other byte. If 32 more bytes are
|
||||
// scanned (or skipped), look at every third byte, etc.. When a match
|
||||
// is found, immediately go back to looking at every byte. This is a
|
||||
// small loss (~5% performance, ~0.1% density) for compressible data
|
||||
// due to more bookkeeping, but for non-compressible data (such as
|
||||
// JPEG) it's a huge win since the compressor quickly "realizes" the
|
||||
// data is incompressible and doesn't bother looking for matches
|
||||
// everywhere.
|
||||
//
|
||||
// The "skip" variable keeps track of how many bytes there are since
|
||||
// the last match; dividing it by 32 (ie. right-shifting by five) gives
|
||||
// the number of bytes to move ahead for each iteration.
|
||||
skip := int32(32)
|
||||
|
||||
nextS := s
|
||||
var candidate tableEntry
|
||||
for {
|
||||
s = nextS
|
||||
bytesBetweenHashLookups := skip >> 5
|
||||
nextS = s + bytesBetweenHashLookups
|
||||
skip += bytesBetweenHashLookups
|
||||
if nextS > sLimit {
|
||||
goto emitRemainder
|
||||
}
|
||||
candidates := e.table[nextHash&tableMask]
|
||||
now := load3232(src, nextS)
|
||||
e.table[nextHash&tableMask] = tableEntryPrev{Prev: candidates.Cur, Cur: tableEntry{offset: s + e.cur, val: cv}}
|
||||
nextHash = hash(now)
|
||||
|
||||
// Check both candidates
|
||||
candidate = candidates.Cur
|
||||
if cv == candidate.val {
|
||||
offset := s - (candidate.offset - e.cur)
|
||||
if offset <= maxMatchOffset {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// We only check if value mismatches.
|
||||
// Offset will always be invalid in other cases.
|
||||
candidate = candidates.Prev
|
||||
if cv == candidate.val {
|
||||
offset := s - (candidate.offset - e.cur)
|
||||
if offset <= maxMatchOffset {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
cv = now
|
||||
}
|
||||
|
||||
// A 4-byte match has been found. We'll later see if more than 4 bytes
|
||||
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
|
||||
// them as literal bytes.
|
||||
emitLiteral(dst, src[nextEmit:s])
|
||||
|
||||
// Call emitCopy, and then see if another emitCopy could be our next
|
||||
// move. Repeat until we find no match for the input immediately after
|
||||
// what was consumed by the last emitCopy call.
|
||||
//
|
||||
// If we exit this loop normally then we need to call emitLiteral next,
|
||||
// though we don't yet know how big the literal will be. We handle that
|
||||
// by proceeding to the next iteration of the main loop. We also can
|
||||
// exit this loop via goto if we get close to exhausting the input.
|
||||
for {
|
||||
// Invariant: we have a 4-byte match at s, and no need to emit any
|
||||
// literal bytes prior to s.
|
||||
|
||||
// Extend the 4-byte match as long as possible.
|
||||
//
|
||||
s += 4
|
||||
t := candidate.offset - e.cur + 4
|
||||
l := e.matchlen(s, t, src)
|
||||
|
||||
// matchToken is flate's equivalent of Snappy's emitCopy. (length,offset)
|
||||
dst.tokens[dst.n] = matchToken(uint32(l+4-baseMatchLength), uint32(s-t-baseMatchOffset))
|
||||
dst.n++
|
||||
s += l
|
||||
nextEmit = s
|
||||
if s >= sLimit {
|
||||
t += l
|
||||
// Index first pair after match end.
|
||||
if int(t+4) < len(src) && t > 0 {
|
||||
cv := load3232(src, t)
|
||||
nextHash = hash(cv)
|
||||
e.table[nextHash&tableMask] = tableEntryPrev{
|
||||
Prev: e.table[nextHash&tableMask].Cur,
|
||||
Cur: tableEntry{offset: e.cur + t, val: cv},
|
||||
}
|
||||
}
|
||||
goto emitRemainder
|
||||
}
|
||||
|
||||
// We could immediately start working at s now, but to improve
|
||||
// compression we first update the hash table at s-3 to s. If
|
||||
// another emitCopy is not our next move, also calculate nextHash
|
||||
// at s+1. At least on GOARCH=amd64, these three hash calculations
|
||||
// are faster as one load64 call (with some shifts) instead of
|
||||
// three load32 calls.
|
||||
x := load6432(src, s-3)
|
||||
prevHash := hash(uint32(x))
|
||||
e.table[prevHash&tableMask] = tableEntryPrev{
|
||||
Prev: e.table[prevHash&tableMask].Cur,
|
||||
Cur: tableEntry{offset: e.cur + s - 3, val: uint32(x)},
|
||||
}
|
||||
x >>= 8
|
||||
prevHash = hash(uint32(x))
|
||||
|
||||
e.table[prevHash&tableMask] = tableEntryPrev{
|
||||
Prev: e.table[prevHash&tableMask].Cur,
|
||||
Cur: tableEntry{offset: e.cur + s - 2, val: uint32(x)},
|
||||
}
|
||||
x >>= 8
|
||||
prevHash = hash(uint32(x))
|
||||
|
||||
e.table[prevHash&tableMask] = tableEntryPrev{
|
||||
Prev: e.table[prevHash&tableMask].Cur,
|
||||
Cur: tableEntry{offset: e.cur + s - 1, val: uint32(x)},
|
||||
}
|
||||
x >>= 8
|
||||
currHash := hash(uint32(x))
|
||||
candidates := e.table[currHash&tableMask]
|
||||
cv = uint32(x)
|
||||
e.table[currHash&tableMask] = tableEntryPrev{
|
||||
Prev: candidates.Cur,
|
||||
Cur: tableEntry{offset: s + e.cur, val: cv},
|
||||
}
|
||||
|
||||
// Check both candidates
|
||||
candidate = candidates.Cur
|
||||
if cv == candidate.val {
|
||||
offset := s - (candidate.offset - e.cur)
|
||||
if offset <= maxMatchOffset {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// We only check if value mismatches.
|
||||
// Offset will always be invalid in other cases.
|
||||
candidate = candidates.Prev
|
||||
if cv == candidate.val {
|
||||
offset := s - (candidate.offset - e.cur)
|
||||
if offset <= maxMatchOffset {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
cv = uint32(x >> 8)
|
||||
nextHash = hash(cv)
|
||||
s++
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
emitRemainder:
|
||||
if int(nextEmit) < len(src) {
|
||||
emitLiteral(dst, src[nextEmit:])
|
||||
}
|
||||
e.cur += int32(len(src))
|
||||
e.prev = e.prev[:len(src)]
|
||||
copy(e.prev, src)
|
||||
}
|
||||
|
||||
// snappyL4
|
||||
type snappyL4 struct {
|
||||
snappyL3
|
||||
}
|
||||
|
||||
// Encode uses a similar algorithm to level 3,
|
||||
// but will check up to two candidates if first isn't long enough.
|
||||
func (e *snappyL4) Encode(dst *tokens, src []byte) {
|
||||
const (
|
||||
inputMargin = 8 - 3
|
||||
minNonLiteralBlockSize = 1 + 1 + inputMargin
|
||||
matchLenGood = 12
|
||||
)
|
||||
|
||||
// Protect against e.cur wraparound.
|
||||
if e.cur > 1<<30 {
|
||||
for i := range e.table[:] {
|
||||
e.table[i] = tableEntryPrev{}
|
||||
}
|
||||
e.snappyGen = snappyGen{cur: maxStoreBlockSize, prev: e.prev[:0]}
|
||||
}
|
||||
|
||||
// This check isn't in the Snappy implementation, but there, the caller
|
||||
// instead of the callee handles this case.
|
||||
if len(src) < minNonLiteralBlockSize {
|
||||
// We do not fill the token table.
|
||||
// This will be picked up by caller.
|
||||
dst.n = uint16(len(src))
|
||||
e.cur += maxStoreBlockSize
|
||||
e.prev = e.prev[:0]
|
||||
return
|
||||
}
|
||||
|
||||
// sLimit is when to stop looking for offset/length copies. The inputMargin
|
||||
// lets us use a fast path for emitLiteral in the main loop, while we are
|
||||
// looking for copies.
|
||||
sLimit := int32(len(src) - inputMargin)
|
||||
|
||||
// nextEmit is where in src the next emitLiteral should start from.
|
||||
nextEmit := int32(0)
|
||||
s := int32(0)
|
||||
cv := load3232(src, s)
|
||||
nextHash := hash(cv)
|
||||
|
||||
for {
|
||||
// Copied from the C++ snappy implementation:
|
||||
//
|
||||
// Heuristic match skipping: If 32 bytes are scanned with no matches
|
||||
// found, start looking only at every other byte. If 32 more bytes are
|
||||
// scanned (or skipped), look at every third byte, etc.. When a match
|
||||
// is found, immediately go back to looking at every byte. This is a
|
||||
// small loss (~5% performance, ~0.1% density) for compressible data
|
||||
// due to more bookkeeping, but for non-compressible data (such as
|
||||
// JPEG) it's a huge win since the compressor quickly "realizes" the
|
||||
// data is incompressible and doesn't bother looking for matches
|
||||
// everywhere.
|
||||
//
|
||||
// The "skip" variable keeps track of how many bytes there are since
|
||||
// the last match; dividing it by 32 (ie. right-shifting by five) gives
|
||||
// the number of bytes to move ahead for each iteration.
|
||||
skip := int32(32)
|
||||
|
||||
nextS := s
|
||||
var candidate tableEntry
|
||||
var candidateAlt tableEntry
|
||||
for {
|
||||
s = nextS
|
||||
bytesBetweenHashLookups := skip >> 5
|
||||
nextS = s + bytesBetweenHashLookups
|
||||
skip += bytesBetweenHashLookups
|
||||
if nextS > sLimit {
|
||||
goto emitRemainder
|
||||
}
|
||||
candidates := e.table[nextHash&tableMask]
|
||||
now := load3232(src, nextS)
|
||||
e.table[nextHash&tableMask] = tableEntryPrev{Prev: candidates.Cur, Cur: tableEntry{offset: s + e.cur, val: cv}}
|
||||
nextHash = hash(now)
|
||||
|
||||
// Check both candidates
|
||||
candidate = candidates.Cur
|
||||
if cv == candidate.val {
|
||||
offset := s - (candidate.offset - e.cur)
|
||||
if offset < maxMatchOffset {
|
||||
offset = s - (candidates.Prev.offset - e.cur)
|
||||
if cv == candidates.Prev.val && offset < maxMatchOffset {
|
||||
candidateAlt = candidates.Prev
|
||||
}
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// We only check if value mismatches.
|
||||
// Offset will always be invalid in other cases.
|
||||
candidate = candidates.Prev
|
||||
if cv == candidate.val {
|
||||
offset := s - (candidate.offset - e.cur)
|
||||
if offset < maxMatchOffset {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
cv = now
|
||||
}
|
||||
|
||||
// A 4-byte match has been found. We'll later see if more than 4 bytes
|
||||
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
|
||||
// them as literal bytes.
|
||||
emitLiteral(dst, src[nextEmit:s])
|
||||
|
||||
// Call emitCopy, and then see if another emitCopy could be our next
|
||||
// move. Repeat until we find no match for the input immediately after
|
||||
// what was consumed by the last emitCopy call.
|
||||
//
|
||||
// If we exit this loop normally then we need to call emitLiteral next,
|
||||
// though we don't yet know how big the literal will be. We handle that
|
||||
// by proceeding to the next iteration of the main loop. We also can
|
||||
// exit this loop via goto if we get close to exhausting the input.
|
||||
for {
|
||||
// Invariant: we have a 4-byte match at s, and no need to emit any
|
||||
// literal bytes prior to s.
|
||||
|
||||
// Extend the 4-byte match as long as possible.
|
||||
//
|
||||
s += 4
|
||||
t := candidate.offset - e.cur + 4
|
||||
l := e.matchlen(s, t, src)
|
||||
// Try alternative candidate if match length < matchLenGood.
|
||||
if l < matchLenGood-4 && candidateAlt.offset != 0 {
|
||||
t2 := candidateAlt.offset - e.cur + 4
|
||||
l2 := e.matchlen(s, t2, src)
|
||||
if l2 > l {
|
||||
l = l2
|
||||
t = t2
|
||||
}
|
||||
}
|
||||
// matchToken is flate's equivalent of Snappy's emitCopy. (length,offset)
|
||||
dst.tokens[dst.n] = matchToken(uint32(l+4-baseMatchLength), uint32(s-t-baseMatchOffset))
|
||||
dst.n++
|
||||
s += l
|
||||
nextEmit = s
|
||||
if s >= sLimit {
|
||||
t += l
|
||||
// Index first pair after match end.
|
||||
if int(t+4) < len(src) && t > 0 {
|
||||
cv := load3232(src, t)
|
||||
nextHash = hash(cv)
|
||||
e.table[nextHash&tableMask] = tableEntryPrev{
|
||||
Prev: e.table[nextHash&tableMask].Cur,
|
||||
Cur: tableEntry{offset: e.cur + t, val: cv},
|
||||
}
|
||||
}
|
||||
goto emitRemainder
|
||||
}
|
||||
|
||||
// We could immediately start working at s now, but to improve
|
||||
// compression we first update the hash table at s-3 to s. If
|
||||
// another emitCopy is not our next move, also calculate nextHash
|
||||
// at s+1. At least on GOARCH=amd64, these three hash calculations
|
||||
// are faster as one load64 call (with some shifts) instead of
|
||||
// three load32 calls.
|
||||
x := load6432(src, s-3)
|
||||
prevHash := hash(uint32(x))
|
||||
e.table[prevHash&tableMask] = tableEntryPrev{
|
||||
Prev: e.table[prevHash&tableMask].Cur,
|
||||
Cur: tableEntry{offset: e.cur + s - 3, val: uint32(x)},
|
||||
}
|
||||
x >>= 8
|
||||
prevHash = hash(uint32(x))
|
||||
|
||||
e.table[prevHash&tableMask] = tableEntryPrev{
|
||||
Prev: e.table[prevHash&tableMask].Cur,
|
||||
Cur: tableEntry{offset: e.cur + s - 2, val: uint32(x)},
|
||||
}
|
||||
x >>= 8
|
||||
prevHash = hash(uint32(x))
|
||||
|
||||
e.table[prevHash&tableMask] = tableEntryPrev{
|
||||
Prev: e.table[prevHash&tableMask].Cur,
|
||||
Cur: tableEntry{offset: e.cur + s - 1, val: uint32(x)},
|
||||
}
|
||||
x >>= 8
|
||||
currHash := hash(uint32(x))
|
||||
candidates := e.table[currHash&tableMask]
|
||||
cv = uint32(x)
|
||||
e.table[currHash&tableMask] = tableEntryPrev{
|
||||
Prev: candidates.Cur,
|
||||
Cur: tableEntry{offset: s + e.cur, val: cv},
|
||||
}
|
||||
|
||||
// Check both candidates
|
||||
candidate = candidates.Cur
|
||||
candidateAlt = tableEntry{}
|
||||
if cv == candidate.val {
|
||||
offset := s - (candidate.offset - e.cur)
|
||||
if offset <= maxMatchOffset {
|
||||
offset = s - (candidates.Prev.offset - e.cur)
|
||||
if cv == candidates.Prev.val && offset <= maxMatchOffset {
|
||||
candidateAlt = candidates.Prev
|
||||
}
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// We only check if value mismatches.
|
||||
// Offset will always be invalid in other cases.
|
||||
candidate = candidates.Prev
|
||||
if cv == candidate.val {
|
||||
offset := s - (candidate.offset - e.cur)
|
||||
if offset <= maxMatchOffset {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
cv = uint32(x >> 8)
|
||||
nextHash = hash(cv)
|
||||
s++
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
emitRemainder:
|
||||
if int(nextEmit) < len(src) {
|
||||
emitLiteral(dst, src[nextEmit:])
|
||||
}
|
||||
e.cur += int32(len(src))
|
||||
e.prev = e.prev[:len(src)]
|
||||
copy(e.prev, src)
|
||||
}
|
||||
|
||||
func (e *snappyGen) matchlen(s, t int32, src []byte) int32 {
|
||||
s1 := int(s) + maxMatchLength - 4
|
||||
if s1 > len(src) {
|
||||
s1 = len(src)
|
||||
}
|
||||
|
||||
// If we are inside the current block
|
||||
if t >= 0 {
|
||||
b := src[t:]
|
||||
a := src[s:s1]
|
||||
b = b[:len(a)]
|
||||
// Extend the match to be as long as possible.
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return int32(i)
|
||||
}
|
||||
}
|
||||
return int32(len(a))
|
||||
}
|
||||
|
||||
// We found a match in the previous block.
|
||||
tp := int32(len(e.prev)) + t
|
||||
if tp < 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Extend the match to be as long as possible.
|
||||
a := src[s:s1]
|
||||
b := e.prev[tp:]
|
||||
if len(b) > len(a) {
|
||||
b = b[:len(a)]
|
||||
}
|
||||
a = a[:len(b)]
|
||||
for i := range b {
|
||||
if a[i] != b[i] {
|
||||
return int32(i)
|
||||
}
|
||||
}
|
||||
|
||||
// If we reached our limit, we matched everything we are
|
||||
// allowed to in the previous block and we return.
|
||||
n := int32(len(b))
|
||||
if int(s+n) == s1 {
|
||||
return n
|
||||
}
|
||||
|
||||
// Continue looking for more matches in the current block.
|
||||
a = src[s+n : s1]
|
||||
b = src[:len(a)]
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return int32(i) + n
|
||||
}
|
||||
}
|
||||
return int32(len(a)) + n
|
||||
}
|
||||
|
||||
// Reset the encoding table.
|
||||
func (e *snappyGen) Reset() {
|
||||
e.prev = e.prev[:0]
|
||||
e.cur += maxMatchOffset
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package flate
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Layout of a token, from the most significant bit down:
//
//	2 bits:  type, 0 = literal, 1 = match (see literalType and matchType)
//	8 bits:  xlength = length - MIN_MATCH_LENGTH
//	22 bits: xoffset = offset - MIN_OFFSET_SIZE, or the literal
const (
	lengthShift = 22
	offsetMask  = 1<<lengthShift - 1
	typeMask    = 3 << 30
	literalType = 0 << 30
	matchType   = 1 << 30
)
|
||||
|
||||
// The length code for length X (MIN_MATCH_LENGTH <= X <= MAX_MATCH_LENGTH)
// is lengthCodes[length - MIN_MATCH_LENGTH]
var lengthCodes = [...]uint32{
	0, 1, 2, 3, 4, 5, 6, 7, 8, 8,
	9, 9, 10, 10, 11, 11, 12, 12, 12, 12,
	13, 13, 13, 13, 14, 14, 14, 14, 15, 15,
	15, 15, 16, 16, 16, 16, 16, 16, 16, 16,
	17, 17, 17, 17, 17, 17, 17, 17, 18, 18,
	18, 18, 18, 18, 18, 18, 19, 19, 19, 19,
	19, 19, 19, 19, 20, 20, 20, 20, 20, 20,
	20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
	21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
	21, 21, 21, 21, 21, 21, 22, 22, 22, 22,
	22, 22, 22, 22, 22, 22, 22, 22, 22, 22,
	22, 22, 23, 23, 23, 23, 23, 23, 23, 23,
	23, 23, 23, 23, 23, 23, 23, 23, 24, 24,
	24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
	24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
	24, 24, 24, 24, 24, 24, 24, 24, 24, 24,
	25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
	25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
	25, 25, 25, 25, 25, 25, 25, 25, 25, 25,
	25, 25, 26, 26, 26, 26, 26, 26, 26, 26,
	26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
	26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
	26, 26, 26, 26, 27, 27, 27, 27, 27, 27,
	27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
	27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
	27, 27, 27, 27, 27, 28,
}
|
||||
|
||||
// offsetCodes is the base lookup table used by offsetCode. It is indexed by
// the offset itself, or by the offset shifted right by 7 or 14 bits for
// larger offsets.
var offsetCodes = [...]uint32{
	0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7,
	8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
	10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
	11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
	12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
	12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
	13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
	13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
	14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
	14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
	14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
	14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
	15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
	15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
	15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
	15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
}
|
||||
|
||||
// token is a single encoded symbol packed into 32 bits: either a literal or
// a match (xlength, xoffset) pair, distinguished by its top two type bits.
type token uint32
|
||||
|
||||
type tokens struct {
|
||||
tokens [maxStoreBlockSize + 1]token
|
||||
n uint16 // Must be able to contain maxStoreBlockSize
|
||||
}
|
||||
|
||||
// Convert a literal into a literal token.
|
||||
func literalToken(literal uint32) token { return token(literalType + literal) }
|
||||
|
||||
// Convert a < xlength, xoffset > pair into a match token.
|
||||
func matchToken(xlength uint32, xoffset uint32) token {
|
||||
return token(matchType + xlength<<lengthShift + xoffset)
|
||||
}
|
||||
|
||||
func matchTokend(xlength uint32, xoffset uint32) token {
|
||||
if xlength > maxMatchLength || xoffset > maxMatchOffset {
|
||||
panic(fmt.Sprintf("Invalid match: len: %d, offset: %d\n", xlength, xoffset))
|
||||
return token(matchType)
|
||||
}
|
||||
return token(matchType + xlength<<lengthShift + xoffset)
|
||||
}
|
||||
|
||||
// Returns the type of a token
|
||||
func (t token) typ() uint32 { return uint32(t) & typeMask }
|
||||
|
||||
// Returns the literal of a literal token
|
||||
func (t token) literal() uint32 { return uint32(t - literalType) }
|
||||
|
||||
// Returns the extra offset of a match token
|
||||
func (t token) offset() uint32 { return uint32(t) & offsetMask }
|
||||
|
||||
// length returns the xlength field of a match token: the bits above
// lengthShift once the matchType tag has been subtracted.
func (t token) length() uint32 { return uint32((t - matchType) >> lengthShift) }
|
||||
|
||||
// lengthCode returns the length code for xlength (length - MIN_MATCH_LENGTH),
// looked up in lengthCodes. It panics if xlength is outside the table.
func lengthCode(xlength uint32) uint32 { return lengthCodes[xlength] }
|
||||
|
||||
// Returns the offset code corresponding to a specific offset
|
||||
func offsetCode(off uint32) uint32 {
|
||||
if off < uint32(len(offsetCodes)) {
|
||||
return offsetCodes[off]
|
||||
} else if off>>7 < uint32(len(offsetCodes)) {
|
||||
return offsetCodes[off>>7] + 14
|
||||
} else {
|
||||
return offsetCodes[off>>14] + 28
|
||||
}
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2012 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,22 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Klaus Post
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
pgzip
|
||||
=====
|
||||
|
||||
Go parallel gzip compression/decompression. This is a fully gzip compatible drop in replacement for "compress/gzip".
|
||||
|
||||
This will split compression into blocks that are compressed in parallel.
|
||||
This can be useful for compressing big amounts of data. The output is a standard gzip file.
|
||||
|
||||
The gzip decompression is modified so it decompresses ahead of the current reader.
|
||||
This means that reads will be non-blocking if the decompressor can keep ahead of your code reading from it.
|
||||
CRC calculation also takes place in a separate goroutine.
|
||||
|
||||
You should only use this if you are (de)compressing big amounts of data,
|
||||
say **more than 1MB** at the time, otherwise you will not see any benefit,
|
||||
and it will likely be faster to use the internal gzip library
|
||||
or [this package](https://github.com/klauspost/compress).
|
||||
|
||||
It is important to note that this library creates and reads *standard gzip files*.
|
||||
You do not have to match the compressor/decompressor to get the described speedups,
|
||||
and the gzip files are fully compatible with other gzip readers/writers.
|
||||
|
||||
A golang variant of this is [bgzf](https://godoc.org/github.com/biogo/hts/bgzf),
|
||||
which has the same feature, as well as seeking in the resulting file.
|
||||
The only drawback is a slightly bigger overhead compared to this and pure gzip.
|
||||
See a comparison below.
|
||||
|
||||
[![GoDoc][1]][2] [![Build Status][3]][4]
|
||||
|
||||
[1]: https://godoc.org/github.com/klauspost/pgzip?status.svg
|
||||
[2]: https://godoc.org/github.com/klauspost/pgzip
|
||||
[3]: https://travis-ci.org/klauspost/pgzip.svg
|
||||
[4]: https://travis-ci.org/klauspost/pgzip
|
||||
|
||||
Installation
|
||||
====
|
||||
```go get github.com/klauspost/pgzip/...```
|
||||
|
||||
You might need to get/update the dependencies:
|
||||
|
||||
```
|
||||
go get -u github.com/klauspost/compress
|
||||
go get -u github.com/klauspost/crc32
|
||||
```
|
||||
|
||||
Usage
|
||||
====
|
||||
[Godoc Documentation](https://godoc.org/github.com/klauspost/pgzip)
|
||||
|
||||
To use as a replacement for gzip, exchange
|
||||
|
||||
```import "compress/gzip"```
|
||||
with
|
||||
```import gzip "github.com/klauspost/pgzip"```.
|
||||
|
||||
# Changes
|
||||
|
||||
* Oct 6, 2016: Fixed an issue if the destination writer returned an error.
|
||||
* Oct 6, 2016: Better buffer reuse, should now generate less garbage.
|
||||
* Oct 6, 2016: Output does not change based on write sizes.
|
||||
* Dec 8, 2015: Decoder now supports the io.WriterTo interface, giving a speedup and less GC pressure.
|
||||
* Oct 9, 2015: Reduced allocations by ~35 by using sync.Pool. ~15% overall speedup.
|
||||
|
||||
Changes in [github.com/klauspost/compress](https://github.com/klauspost/compress#changelog) are also carried over, so see that for more changes.
|
||||
|
||||
## Compression
|
||||
The simplest way to use this is to simply do the same as you would when using [compress/gzip](http://golang.org/pkg/compress/gzip).
|
||||
|
||||
To change the block size, use the added (*pgzip.Writer).SetConcurrency(blockSize, blocks int) function. With this you can control the approximate size of your blocks, as well as how many you want to be processing in parallel. Default values for this is SetConcurrency(250000, 16), meaning blocks are split at 250000 bytes and up to 16 blocks can be processing at once before the writer blocks.
|
||||
|
||||
|
||||
Example:
|
||||
```
|
||||
var b bytes.Buffer
|
||||
w := gzip.NewWriter(&b)
|
||||
w.SetConcurrency(100000, 10)
|
||||
w.Write([]byte("hello, world\n"))
|
||||
w.Close()
|
||||
```
|
||||
|
||||
To get any performance gains, you should at least be compressing more than 1 megabyte of data at the time.
|
||||
|
||||
You should have a block size of at least 100k and at least as many blocks as the number of cores you would like to utilize, but about twice that number of blocks would be best.
|
||||
|
||||
Another side effect of this is that it is likely to speed up your other code, since writes to the compressor only block if the compressor is already compressing the number of blocks you have specified. This also means you don't have to worry about buffering input to the compressor.
|
||||
|
||||
## Decompression
|
||||
|
||||
Decompression works similarly to compression. That means that you simply call pgzip the same way as you would call [compress/gzip](http://golang.org/pkg/compress/gzip).
|
||||
|
||||
The only difference is that if you want to specify your own readahead, you have to use `pgzip.NewReaderN(r io.Reader, blockSize, blocks int)` to get a reader with your custom blocksizes. The `blockSize` is the size of each block decoded, and `blocks` is the maximum number of blocks that is decoded ahead.
|
||||
|
||||
See [Example on playground](http://play.golang.org/p/uHv1B5NbDh)
|
||||
|
||||
Performance
|
||||
====
|
||||
## Compression
|
||||
|
||||
See my blog post in [Benchmarks of Golang Gzip](https://blog.klauspost.com/go-gzipdeflate-benchmarks/).
|
||||
|
||||
Compression cost is usually about 0.2% with default settings with a block size of 250k.
|
||||
|
||||
Example with GOMAXPROCS set to 8 (quad core with 8 hyperthreads)
|
||||
|
||||
Content is [Matt Mahoneys 10GB corpus](http://mattmahoney.net/dc/10gb.html). Compression level 6.
|
||||
|
||||
Compressor | MB/sec | speedup | size | size overhead (lower=better)
|
||||
------------|----------|---------|------|---------
|
||||
[gzip](http://golang.org/pkg/compress/gzip) (golang) | 7.21MB/s | 1.0x | 4786608902 | 0%
|
||||
[gzip](http://github.com/klauspost/compress/gzip) (klauspost) | 10.98MB/s | 1.52x | 4781331645 | -0.11%
|
||||
[pgzip](https://github.com/klauspost/pgzip) (klauspost) | 50.76MB/s|7.04x | 4784121440 | -0.052%
|
||||
[bgzf](https://godoc.org/github.com/biogo/hts/bgzf) (biogo) | 38.65MB/s | 5.36x | 4924899484 | 2.889%
|
||||
[pargzip](https://godoc.org/github.com/golang/build/pargzip) (builder) | 32.00MB/s | 4.44x | 4791226567 | 0.096%
|
||||
|
||||
pgzip also contains a [linear time compression](https://github.com/klauspost/compress#linear-time-compression) mode, that will allow compression at ~150MB per core per second, independent of the content.
|
||||
|
||||
See the [complete sheet](https://docs.google.com/spreadsheets/d/1nuNE2nPfuINCZJRMt6wFWhKpToF95I47XjSsc-1rbPQ/edit?usp=sharing) for different content types and compression settings.
|
||||
|
||||
## Decompression
|
||||
|
||||
The decompression speedup is there because it allows you to do other work while the decompression is taking place.
|
||||
|
||||
In the example above, the numbers are as follows on a 4 CPU machine:
|
||||
|
||||
Decompressor | Time | Speedup
|
||||
-------------|------|--------
|
||||
[gzip](http://golang.org/pkg/compress/gzip) (golang) | 1m28.85s | 0%
|
||||
[pgzip](https://github.com/klauspost/pgzip) (golang) | 43.48s | 104%
|
||||
|
||||
But wait, since gzip decompression is inherently singlethreaded (aside from CRC calculation) how can it be more than 100% faster? Because pgzip due to its design also acts as a buffer. When using unbuffered gzip, you are also waiting for io when you are decompressing. If the gzip decoder can keep up, it will always have data ready for your reader, and you will not be waiting for input to the gzip decompressor to complete.
|
||||
|
||||
This is pretty much an optimal situation for pgzip, but it reflects most common usecases for CPU intensive gzip usage.
|
||||
|
||||
I haven't included [bgzf](https://godoc.org/github.com/biogo/hts/bgzf) in this comparison, since it only can decompress files created by a compatible encoder, and therefore cannot be considered a generic gzip decompressor. But if you are able to compress your files with a bgzf compatible program, you can expect it to scale beyond 100%.
|
||||
|
||||
# License
|
||||
This contains large portions of code from the go repository - see GO_LICENSE for more information. The changes are released under MIT License. See LICENSE for more information.
|
|
@ -0,0 +1,7 @@
|
|||
test:
|
||||
pre:
|
||||
- go vet ./...
|
||||
|
||||
override:
|
||||
- go test -v -cpu=1,2,4 .
|
||||
- go test -v -cpu=2 -race -short .
|
|
@ -0,0 +1,573 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package pgzip implements reading and writing of gzip format compressed files,
|
||||
// as specified in RFC 1952.
|
||||
//
|
||||
// This is a drop in replacement for "compress/gzip".
|
||||
// This will split compression into blocks that are compressed in parallel.
|
||||
// This can be useful for compressing big amounts of data.
|
||||
// The gzip decompression has not been modified, but remains in the package,
|
||||
// so you can use it as a complete replacement for "compress/gzip".
|
||||
//
|
||||
// See more at https://github.com/klauspost/pgzip
|
||||
package pgzip
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"hash"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/flate"
|
||||
"github.com/klauspost/crc32"
|
||||
)
|
||||
|
||||
const (
|
||||
gzipID1 = 0x1f
|
||||
gzipID2 = 0x8b
|
||||
gzipDeflate = 8
|
||||
flagText = 1 << 0
|
||||
flagHdrCrc = 1 << 1
|
||||
flagExtra = 1 << 2
|
||||
flagName = 1 << 3
|
||||
flagComment = 1 << 4
|
||||
)
|
||||
|
||||
func makeReader(r io.Reader) flate.Reader {
|
||||
if rr, ok := r.(flate.Reader); ok {
|
||||
return rr
|
||||
}
|
||||
return bufio.NewReader(r)
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrChecksum is returned when reading GZIP data that has an invalid checksum.
|
||||
ErrChecksum = errors.New("gzip: invalid checksum")
|
||||
// ErrHeader is returned when reading GZIP data that has an invalid header.
|
||||
ErrHeader = errors.New("gzip: invalid header")
|
||||
)
|
||||
|
||||
// Header holds the metadata that a gzip file stores in its header.
// It is embedded in both Reader and Writer, so its fields are exposed
// directly on those types.
type Header struct {
	// Comment is the comment string.
	Comment string
	// Extra is the "extra data" field.
	Extra []byte
	// ModTime is the modification time.
	ModTime time.Time
	// Name is the file name.
	Name string
	// OS is the operating system type.
	OS byte
}
|
||||
|
||||
// A Reader is an io.Reader that can be read to retrieve
|
||||
// uncompressed data from a gzip-format compressed file.
|
||||
//
|
||||
// In general, a gzip file can be a concatenation of gzip files,
|
||||
// each with its own header. Reads from the Reader
|
||||
// return the concatenation of the uncompressed data of each.
|
||||
// Only the first header is recorded in the Reader fields.
|
||||
//
|
||||
// Gzip files store a length and checksum of the uncompressed data.
|
||||
// The Reader will return a ErrChecksum when Read
|
||||
// reaches the end of the uncompressed data if it does not
|
||||
// have the expected length or checksum. Clients should treat data
|
||||
// returned by Read as tentative until they receive the io.EOF
|
||||
// marking the end of the data.
|
||||
type Reader struct {
|
||||
Header
|
||||
r flate.Reader
|
||||
decompressor io.ReadCloser
|
||||
digest hash.Hash32
|
||||
size uint32
|
||||
flg byte
|
||||
buf [512]byte
|
||||
err error
|
||||
closeErr chan error
|
||||
multistream bool
|
||||
|
||||
readAhead chan read
|
||||
roff int // read offset
|
||||
current []byte
|
||||
closeReader chan struct{}
|
||||
lastBlock bool
|
||||
blockSize int
|
||||
blocks int
|
||||
|
||||
activeRA bool // Indication if readahead is active
|
||||
mu sync.Mutex // Lock for above
|
||||
|
||||
blockPool chan []byte
|
||||
}
|
||||
|
||||
// read is one item sent by the read-ahead goroutine: a decoded block and the
// error, if any, that was encountered while filling it.
type read struct {
	b   []byte
	err error
}
|
||||
|
||||
// NewReader creates a new Reader reading the given reader.
|
||||
// The implementation buffers input and may read more data than necessary from r.
|
||||
// It is the caller's responsibility to call Close on the Reader when done.
|
||||
func NewReader(r io.Reader) (*Reader, error) {
|
||||
z := new(Reader)
|
||||
z.blocks = defaultBlocks
|
||||
z.blockSize = defaultBlockSize
|
||||
z.r = makeReader(r)
|
||||
z.digest = crc32.NewIEEE()
|
||||
z.multistream = true
|
||||
z.blockPool = make(chan []byte, z.blocks)
|
||||
for i := 0; i < z.blocks; i++ {
|
||||
z.blockPool <- make([]byte, z.blockSize)
|
||||
}
|
||||
if err := z.readHeader(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return z, nil
|
||||
}
|
||||
|
||||
// NewReaderN creates a new Reader reading the given reader.
|
||||
// The implementation buffers input and may read more data than necessary from r.
|
||||
// It is the caller's responsibility to call Close on the Reader when done.
|
||||
//
|
||||
// With this you can control the approximate size of your blocks,
|
||||
// as well as how many blocks you want to have prefetched.
|
||||
//
|
||||
// Default values for this is blockSize = 250000, blocks = 16,
|
||||
// meaning up to 16 blocks of maximum 250000 bytes will be
|
||||
// prefetched.
|
||||
func NewReaderN(r io.Reader, blockSize, blocks int) (*Reader, error) {
|
||||
z := new(Reader)
|
||||
z.blocks = blocks
|
||||
z.blockSize = blockSize
|
||||
z.r = makeReader(r)
|
||||
z.digest = crc32.NewIEEE()
|
||||
z.multistream = true
|
||||
|
||||
// Account for too small values
|
||||
if z.blocks <= 0 {
|
||||
z.blocks = defaultBlocks
|
||||
}
|
||||
if z.blockSize <= 512 {
|
||||
z.blockSize = defaultBlockSize
|
||||
}
|
||||
z.blockPool = make(chan []byte, z.blocks)
|
||||
for i := 0; i < z.blocks; i++ {
|
||||
z.blockPool <- make([]byte, z.blockSize)
|
||||
}
|
||||
if err := z.readHeader(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return z, nil
|
||||
}
|
||||
|
||||
// Reset discards the Reader z's state and makes it equivalent to the
|
||||
// result of its original state from NewReader, but reading from r instead.
|
||||
// This permits reusing a Reader rather than allocating a new one.
|
||||
func (z *Reader) Reset(r io.Reader) error {
|
||||
z.killReadAhead()
|
||||
z.r = makeReader(r)
|
||||
z.digest = crc32.NewIEEE()
|
||||
z.size = 0
|
||||
z.err = nil
|
||||
z.multistream = true
|
||||
|
||||
// Account for uninitialized values
|
||||
if z.blocks <= 0 {
|
||||
z.blocks = defaultBlocks
|
||||
}
|
||||
if z.blockSize <= 512 {
|
||||
z.blockSize = defaultBlockSize
|
||||
}
|
||||
|
||||
if z.blockPool == nil {
|
||||
z.blockPool = make(chan []byte, z.blocks)
|
||||
for i := 0; i < z.blocks; i++ {
|
||||
z.blockPool <- make([]byte, z.blockSize)
|
||||
}
|
||||
}
|
||||
|
||||
return z.readHeader(true)
|
||||
}
|
||||
|
||||
// Multistream controls whether the reader supports multistream files.
|
||||
//
|
||||
// If enabled (the default), the Reader expects the input to be a sequence
|
||||
// of individually gzipped data streams, each with its own header and
|
||||
// trailer, ending at EOF. The effect is that the concatenation of a sequence
|
||||
// of gzipped files is treated as equivalent to the gzip of the concatenation
|
||||
// of the sequence. This is standard behavior for gzip readers.
|
||||
//
|
||||
// Calling Multistream(false) disables this behavior; disabling the behavior
|
||||
// can be useful when reading file formats that distinguish individual gzip
|
||||
// data streams or mix gzip data streams with other data streams.
|
||||
// In this mode, when the Reader reaches the end of the data stream,
|
||||
// Read returns io.EOF. If the underlying reader implements io.ByteReader,
|
||||
// it will be left positioned just after the gzip stream.
|
||||
// To start the next stream, call z.Reset(r) followed by z.Multistream(false).
|
||||
// If there is no next stream, z.Reset(r) will return io.EOF.
|
||||
func (z *Reader) Multistream(ok bool) {
|
||||
z.multistream = ok
|
||||
}
|
||||
|
||||
// get4 decodes the first four bytes of p as a little-endian uint32.
// GZIP (RFC 1952) is little-endian, unlike ZLIB (RFC 1950).
func get4(p []byte) uint32 {
	var v uint32
	// Fold from the most significant byte (p[3]) down to p[0].
	for i := 3; i >= 0; i-- {
		v = v<<8 | uint32(p[i])
	}
	return v
}
|
||||
|
||||
func (z *Reader) readString() (string, error) {
|
||||
var err error
|
||||
needconv := false
|
||||
for i := 0; ; i++ {
|
||||
if i >= len(z.buf) {
|
||||
return "", ErrHeader
|
||||
}
|
||||
z.buf[i], err = z.r.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if z.buf[i] > 0x7f {
|
||||
needconv = true
|
||||
}
|
||||
if z.buf[i] == 0 {
|
||||
// GZIP (RFC 1952) specifies that strings are NUL-terminated ISO 8859-1 (Latin-1).
|
||||
if needconv {
|
||||
s := make([]rune, 0, i)
|
||||
for _, v := range z.buf[0:i] {
|
||||
s = append(s, rune(v))
|
||||
}
|
||||
return string(s), nil
|
||||
}
|
||||
return string(z.buf[0:i]), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (z *Reader) read2() (uint32, error) {
|
||||
_, err := io.ReadFull(z.r, z.buf[0:2])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return uint32(z.buf[0]) | uint32(z.buf[1])<<8, nil
|
||||
}
|
||||
|
||||
func (z *Reader) readHeader(save bool) error {
|
||||
z.killReadAhead()
|
||||
|
||||
_, err := io.ReadFull(z.r, z.buf[0:10])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if z.buf[0] != gzipID1 || z.buf[1] != gzipID2 || z.buf[2] != gzipDeflate {
|
||||
return ErrHeader
|
||||
}
|
||||
z.flg = z.buf[3]
|
||||
if save {
|
||||
z.ModTime = time.Unix(int64(get4(z.buf[4:8])), 0)
|
||||
// z.buf[8] is xfl, ignored
|
||||
z.OS = z.buf[9]
|
||||
}
|
||||
z.digest.Reset()
|
||||
z.digest.Write(z.buf[0:10])
|
||||
|
||||
if z.flg&flagExtra != 0 {
|
||||
n, err := z.read2()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data := make([]byte, n)
|
||||
if _, err = io.ReadFull(z.r, data); err != nil {
|
||||
return err
|
||||
}
|
||||
if save {
|
||||
z.Extra = data
|
||||
}
|
||||
}
|
||||
|
||||
var s string
|
||||
if z.flg&flagName != 0 {
|
||||
if s, err = z.readString(); err != nil {
|
||||
return err
|
||||
}
|
||||
if save {
|
||||
z.Name = s
|
||||
}
|
||||
}
|
||||
|
||||
if z.flg&flagComment != 0 {
|
||||
if s, err = z.readString(); err != nil {
|
||||
return err
|
||||
}
|
||||
if save {
|
||||
z.Comment = s
|
||||
}
|
||||
}
|
||||
|
||||
if z.flg&flagHdrCrc != 0 {
|
||||
n, err := z.read2()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sum := z.digest.Sum32() & 0xFFFF
|
||||
if n != sum {
|
||||
return ErrHeader
|
||||
}
|
||||
}
|
||||
|
||||
z.digest.Reset()
|
||||
z.decompressor = flate.NewReader(z.r)
|
||||
z.doReadAhead()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (z *Reader) killReadAhead() error {
|
||||
z.mu.Lock()
|
||||
defer z.mu.Unlock()
|
||||
if z.activeRA {
|
||||
if z.closeReader != nil {
|
||||
close(z.closeReader)
|
||||
}
|
||||
|
||||
// Wait for decompressor to be closed and return error, if any.
|
||||
e, ok := <-z.closeErr
|
||||
z.activeRA = false
|
||||
if !ok {
|
||||
// Channel is closed, so if there was any error it has already been returned.
|
||||
return nil
|
||||
}
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Starts readahead.
|
||||
// Will return on error (including io.EOF)
|
||||
// or when z.closeReader is closed.
|
||||
func (z *Reader) doReadAhead() {
|
||||
z.mu.Lock()
|
||||
defer z.mu.Unlock()
|
||||
z.activeRA = true
|
||||
|
||||
if z.blocks <= 0 {
|
||||
z.blocks = defaultBlocks
|
||||
}
|
||||
if z.blockSize <= 512 {
|
||||
z.blockSize = defaultBlockSize
|
||||
}
|
||||
ra := make(chan read, z.blocks)
|
||||
z.readAhead = ra
|
||||
closeReader := make(chan struct{}, 0)
|
||||
z.closeReader = closeReader
|
||||
z.lastBlock = false
|
||||
closeErr := make(chan error, 1)
|
||||
z.closeErr = closeErr
|
||||
z.size = 0
|
||||
z.roff = 0
|
||||
z.current = nil
|
||||
decomp := z.decompressor
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
closeErr <- decomp.Close()
|
||||
close(closeErr)
|
||||
close(ra)
|
||||
}()
|
||||
|
||||
// We hold a local reference to digest, since
|
||||
// it way be changed by reset.
|
||||
digest := z.digest
|
||||
var wg sync.WaitGroup
|
||||
for {
|
||||
var buf []byte
|
||||
select {
|
||||
case buf = <-z.blockPool:
|
||||
case <-closeReader:
|
||||
return
|
||||
}
|
||||
buf = buf[0:z.blockSize]
|
||||
// Try to fill the buffer
|
||||
n, err := io.ReadFull(decomp, buf)
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
if n > 0 {
|
||||
err = nil
|
||||
} else {
|
||||
// If we got zero bytes, we need to establish if
|
||||
// we reached end of stream or truncated stream.
|
||||
_, err = decomp.Read([]byte{})
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if n < len(buf) {
|
||||
buf = buf[0:n]
|
||||
}
|
||||
wg.Wait()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
digest.Write(buf)
|
||||
wg.Done()
|
||||
}()
|
||||
z.size += uint32(n)
|
||||
|
||||
// If we return any error, out digest must be ready
|
||||
if err != nil {
|
||||
wg.Wait()
|
||||
}
|
||||
select {
|
||||
case z.readAhead <- read{b: buf, err: err}:
|
||||
case <-closeReader:
|
||||
// Sent on close, we don't care about the next results
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (z *Reader) Read(p []byte) (n int, err error) {
|
||||
if z.err != nil {
|
||||
return 0, z.err
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
for {
|
||||
if len(z.current) == 0 && !z.lastBlock {
|
||||
read := <-z.readAhead
|
||||
|
||||
if read.err != nil {
|
||||
// If not nil, the reader will have exited
|
||||
z.closeReader = nil
|
||||
|
||||
if read.err != io.EOF {
|
||||
z.err = read.err
|
||||
return
|
||||
}
|
||||
if read.err == io.EOF {
|
||||
z.lastBlock = true
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
z.current = read.b
|
||||
z.roff = 0
|
||||
}
|
||||
avail := z.current[z.roff:]
|
||||
if len(p) >= len(avail) {
|
||||
// If len(p) >= len(current), return all content of current
|
||||
n = copy(p, avail)
|
||||
z.blockPool <- z.current
|
||||
z.current = nil
|
||||
if z.lastBlock {
|
||||
err = io.EOF
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// We copy as much as there is space for
|
||||
n = copy(p, avail)
|
||||
z.roff += n
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Finished file; check checksum + size.
|
||||
if _, err := io.ReadFull(z.r, z.buf[0:8]); err != nil {
|
||||
z.err = err
|
||||
return 0, err
|
||||
}
|
||||
crc32, isize := get4(z.buf[0:4]), get4(z.buf[4:8])
|
||||
sum := z.digest.Sum32()
|
||||
if sum != crc32 || isize != z.size {
|
||||
z.err = ErrChecksum
|
||||
return 0, z.err
|
||||
}
|
||||
|
||||
// File is ok; should we attempt reading one more?
|
||||
if !z.multistream {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Is there another?
|
||||
if err = z.readHeader(false); err != nil {
|
||||
z.err = err
|
||||
return
|
||||
}
|
||||
|
||||
// Yes. Reset and read from it.
|
||||
return z.Read(p)
|
||||
}
|
||||
|
||||
func (z *Reader) WriteTo(w io.Writer) (n int64, err error) {
|
||||
total := int64(0)
|
||||
for {
|
||||
if z.err != nil {
|
||||
return total, z.err
|
||||
}
|
||||
// We write both to output and digest.
|
||||
for {
|
||||
// Read from input
|
||||
read := <-z.readAhead
|
||||
if read.err != nil {
|
||||
// If not nil, the reader will have exited
|
||||
z.closeReader = nil
|
||||
|
||||
if read.err != io.EOF {
|
||||
z.err = read.err
|
||||
return total, z.err
|
||||
}
|
||||
if read.err == io.EOF {
|
||||
z.lastBlock = true
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
// Write what we got
|
||||
n, err := w.Write(read.b)
|
||||
if n != len(read.b) {
|
||||
return total, io.ErrShortWrite
|
||||
}
|
||||
total += int64(n)
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
// Put block back
|
||||
z.blockPool <- read.b
|
||||
if z.lastBlock {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Finished file; check checksum + size.
|
||||
if _, err := io.ReadFull(z.r, z.buf[0:8]); err != nil {
|
||||
z.err = err
|
||||
return total, err
|
||||
}
|
||||
crc32, isize := get4(z.buf[0:4]), get4(z.buf[4:8])
|
||||
sum := z.digest.Sum32()
|
||||
if sum != crc32 || isize != z.size {
|
||||
z.err = ErrChecksum
|
||||
return total, z.err
|
||||
}
|
||||
// File is ok; should we attempt reading one more?
|
||||
if !z.multistream {
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// Is there another?
|
||||
err = z.readHeader(false)
|
||||
if err == io.EOF {
|
||||
return total, nil
|
||||
}
|
||||
if err != nil {
|
||||
z.err = err
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the Reader. It does not close the underlying io.Reader.
|
||||
func (z *Reader) Close() error {
|
||||
return z.killReadAhead()
|
||||
}
|
|
@ -0,0 +1,501 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package pgzip
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/flate"
|
||||
"github.com/klauspost/crc32"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBlockSize = 256 << 10
|
||||
tailSize = 16384
|
||||
defaultBlocks = 16
|
||||
)
|
||||
|
||||
// These constants are copied from the flate package, so that code that imports
|
||||
// "compress/gzip" does not also have to import "compress/flate".
|
||||
const (
|
||||
NoCompression = flate.NoCompression
|
||||
BestSpeed = flate.BestSpeed
|
||||
BestCompression = flate.BestCompression
|
||||
DefaultCompression = flate.DefaultCompression
|
||||
ConstantCompression = flate.ConstantCompression
|
||||
HuffmanOnly = flate.HuffmanOnly
|
||||
)
|
||||
|
||||
// A Writer is an io.WriteCloser.
|
||||
// Writes to a Writer are compressed and written to w.
|
||||
type Writer struct {
|
||||
Header
|
||||
w io.Writer
|
||||
level int
|
||||
wroteHeader bool
|
||||
blockSize int
|
||||
blocks int
|
||||
currentBuffer []byte
|
||||
prevTail []byte
|
||||
digest hash.Hash32
|
||||
size int
|
||||
closed bool
|
||||
buf [10]byte
|
||||
errMu sync.RWMutex
|
||||
err error
|
||||
pushedErr chan struct{}
|
||||
results chan result
|
||||
dictFlatePool sync.Pool
|
||||
dstPool sync.Pool
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// result carries the two channels associated with one submitted block:
// result for its output bytes and notifyWritten, which compressCurrent
// waits on when flushing.
type result struct {
	result        chan []byte
	notifyWritten chan struct{}
}
|
||||
|
||||
// Use SetConcurrency to finetune the concurrency level if needed.
|
||||
//
|
||||
// With this you can control the approximate size of your blocks,
|
||||
// as well as how many you want to be processing in parallel.
|
||||
//
|
||||
// Default values for this is SetConcurrency(250000, 16),
|
||||
// meaning blocks are split at 250000 bytes and up to 16 blocks
|
||||
// can be processing at once before the writer blocks.
|
||||
func (z *Writer) SetConcurrency(blockSize, blocks int) error {
|
||||
if blockSize <= tailSize {
|
||||
return fmt.Errorf("gzip: block size cannot be less than or equal to %d", tailSize)
|
||||
}
|
||||
if blocks <= 0 {
|
||||
return errors.New("gzip: blocks cannot be zero or less")
|
||||
}
|
||||
if blockSize == z.blockSize && blocks == z.blocks {
|
||||
return nil
|
||||
}
|
||||
z.blockSize = blockSize
|
||||
z.results = make(chan result, blocks)
|
||||
z.blocks = blocks
|
||||
z.dstPool = sync.Pool{New: func() interface{} { return make([]byte, 0, blockSize+(blockSize)>>4) }}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewWriter returns a new Writer.
|
||||
// Writes to the returned writer are compressed and written to w.
|
||||
//
|
||||
// It is the caller's responsibility to call Close on the WriteCloser when done.
|
||||
// Writes may be buffered and not flushed until Close.
|
||||
//
|
||||
// Callers that wish to set the fields in Writer.Header must do so before
|
||||
// the first call to Write or Close. The Comment and Name header fields are
|
||||
// UTF-8 strings in Go, but the underlying format requires NUL-terminated ISO
|
||||
// 8859-1 (Latin-1). NUL or non-Latin-1 runes in those strings will lead to an
|
||||
// error on Write.
|
||||
func NewWriter(w io.Writer) *Writer {
|
||||
z, _ := NewWriterLevel(w, DefaultCompression)
|
||||
return z
|
||||
}
|
||||
|
||||
// NewWriterLevel is like NewWriter but specifies the compression level instead
|
||||
// of assuming DefaultCompression.
|
||||
//
|
||||
// The compression level can be DefaultCompression, NoCompression, or any
|
||||
// integer value between BestSpeed and BestCompression inclusive. The error
|
||||
// returned will be nil if the level is valid.
|
||||
func NewWriterLevel(w io.Writer, level int) (*Writer, error) {
|
||||
if level < ConstantCompression || level > BestCompression {
|
||||
return nil, fmt.Errorf("gzip: invalid compression level: %d", level)
|
||||
}
|
||||
z := new(Writer)
|
||||
z.SetConcurrency(defaultBlockSize, defaultBlocks)
|
||||
z.init(w, level)
|
||||
return z, nil
|
||||
}
|
||||
|
||||
// This function must be used by goroutines to set an
|
||||
// error condition, since z.err access is restricted
|
||||
// to the callers goruotine.
|
||||
func (z *Writer) pushError(err error) {
|
||||
z.errMu.Lock()
|
||||
if z.err != nil {
|
||||
z.errMu.Unlock()
|
||||
return
|
||||
}
|
||||
z.err = err
|
||||
close(z.pushedErr)
|
||||
z.errMu.Unlock()
|
||||
}
|
||||
|
||||
func (z *Writer) init(w io.Writer, level int) {
|
||||
z.wg.Wait()
|
||||
digest := z.digest
|
||||
if digest != nil {
|
||||
digest.Reset()
|
||||
} else {
|
||||
digest = crc32.NewIEEE()
|
||||
}
|
||||
z.Header = Header{OS: 255}
|
||||
z.w = w
|
||||
z.level = level
|
||||
z.digest = digest
|
||||
z.pushedErr = make(chan struct{}, 0)
|
||||
z.results = make(chan result, z.blocks)
|
||||
z.err = nil
|
||||
z.closed = false
|
||||
z.Comment = ""
|
||||
z.Extra = nil
|
||||
z.ModTime = time.Time{}
|
||||
z.wroteHeader = false
|
||||
z.currentBuffer = nil
|
||||
z.buf = [10]byte{}
|
||||
z.prevTail = nil
|
||||
z.size = 0
|
||||
if z.dictFlatePool.New == nil {
|
||||
z.dictFlatePool.New = func() interface{} {
|
||||
f, _ := flate.NewWriterDict(w, level, nil)
|
||||
return f
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset discards the Writer z's state and makes it equivalent to the
|
||||
// result of its original state from NewWriter or NewWriterLevel, but
|
||||
// writing to w instead. This permits reusing a Writer rather than
|
||||
// allocating a new one.
|
||||
func (z *Writer) Reset(w io.Writer) {
|
||||
if z.results != nil && !z.closed {
|
||||
close(z.results)
|
||||
}
|
||||
z.SetConcurrency(defaultBlockSize, defaultBlocks)
|
||||
z.init(w, z.level)
|
||||
}
|
||||
|
||||
// put2 stores v in the first two bytes of p in little-endian order.
// GZIP (RFC 1952) is little-endian, unlike ZLIB (RFC 1950).
func put2(p []byte, v uint16) {
	for i := 0; i < 2; i++ {
		p[i] = uint8(v >> (8 * uint(i)))
	}
}
|
||||
|
||||
// put4 stores v in the first four bytes of p in little-endian order.
func put4(p []byte, v uint32) {
	for i := 0; i < 4; i++ {
		p[i] = uint8(v >> (8 * uint(i)))
	}
}
|
||||
|
||||
// writeBytes writes a length-prefixed byte slice to z.w.
|
||||
func (z *Writer) writeBytes(b []byte) error {
|
||||
if len(b) > 0xffff {
|
||||
return errors.New("gzip.Write: Extra data is too large")
|
||||
}
|
||||
put2(z.buf[0:2], uint16(len(b)))
|
||||
_, err := z.w.Write(z.buf[0:2])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = z.w.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
// writeString writes a UTF-8 string s in GZIP's format to z.w.
|
||||
// GZIP (RFC 1952) specifies that strings are NUL-terminated ISO 8859-1 (Latin-1).
|
||||
func (z *Writer) writeString(s string) (err error) {
|
||||
// GZIP stores Latin-1 strings; error if non-Latin-1; convert if non-ASCII.
|
||||
needconv := false
|
||||
for _, v := range s {
|
||||
if v == 0 || v > 0xff {
|
||||
return errors.New("gzip.Write: non-Latin-1 header string")
|
||||
}
|
||||
if v > 0x7f {
|
||||
needconv = true
|
||||
}
|
||||
}
|
||||
if needconv {
|
||||
b := make([]byte, 0, len(s))
|
||||
for _, v := range s {
|
||||
b = append(b, byte(v))
|
||||
}
|
||||
_, err = z.w.Write(b)
|
||||
} else {
|
||||
_, err = io.WriteString(z.w, s)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// GZIP strings are NUL-terminated.
|
||||
z.buf[0] = 0
|
||||
_, err = z.w.Write(z.buf[0:1])
|
||||
return err
|
||||
}
|
||||
|
||||
// compressCurrent will compress the data currently buffered in
// z.currentBuffer by handing it to a compressBlock goroutine, and then
// installs an empty buffer taken from z.dstPool as the new currentBuffer.
//
// A result slot is queued on z.results before compression starts, so slots
// are consumed in the order blocks were submitted. If z.pushedErr becomes
// readable instead (presumably signalled by pushError — confirm), the call
// returns without compressing anything.
//
// When flush is true the call blocks until the block's notifyWritten channel
// is closed.
//
// This should only be called from the main writer/flush/closer
func (z *Writer) compressCurrent(flush bool) {
	r := result{}
	r.result = make(chan []byte, 1)
	r.notifyWritten = make(chan struct{}, 0)
	// Reserve this block's place in the output queue, or give up if an error
	// has been signalled.
	select {
	case z.results <- r:
	case <-z.pushedErr:
		return
	}

	// If block given is more than twice the block size, split it.
	c := z.currentBuffer
	if len(c) > z.blockSize*2 {
		c = c[:z.blockSize]
		z.wg.Add(1)
		// Intermediate pieces are never the final block, hence closed=false.
		go z.compressBlock(c, z.prevTail, r, false)
		// NOTE(review): assumes z.blockSize >= tailSize, otherwise this slice
		// expression panics — confirm where blockSize is validated.
		z.prevTail = c[len(c)-tailSize:]
		z.currentBuffer = z.currentBuffer[z.blockSize:]
		z.compressCurrent(flush)
		// Last one flushes if needed
		return
	}

	z.wg.Add(1)
	go z.compressBlock(c, z.prevTail, r, z.closed)
	// Remember the last tailSize bytes; they are passed as prevTail to the
	// next block's compressBlock call.
	if len(c) > tailSize {
		z.prevTail = c[len(c)-tailSize:]
	} else {
		z.prevTail = nil
	}
	z.currentBuffer = z.dstPool.Get().([]byte)
	z.currentBuffer = z.currentBuffer[:0]

	// Wait if flushing
	if flush {
		<-r.notifyWritten
	}
}
|
||||
|
||||
// Returns an error if it has been set.
|
||||
// Cannot be used by functions that are from internal goroutines.
|
||||
func (z *Writer) checkError() error {
|
||||
z.errMu.RLock()
|
||||
err := z.err
|
||||
z.errMu.RUnlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Write writes a compressed form of p to the underlying io.Writer. The
|
||||
// compressed bytes are not necessarily flushed to output until
|
||||
// the Writer is closed or Flush() is called.
|
||||
//
|
||||
// The function will return quickly, if there are unused buffers.
|
||||
// The sent slice (p) is copied, and the caller is free to re-use the buffer
|
||||
// when the function returns.
|
||||
//
|
||||
// Errors that occur during compression will be reported later, and a nil error
|
||||
// does not signify that the compression succeeded (since it is most likely still running)
|
||||
// That means that the call that returns an error may not be the call that caused it.
|
||||
// Only Flush and Close functions are guaranteed to return any errors up to that point.
|
||||
func (z *Writer) Write(p []byte) (int, error) {
|
||||
if err := z.checkError(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Write the GZIP header lazily.
|
||||
if !z.wroteHeader {
|
||||
z.wroteHeader = true
|
||||
z.buf[0] = gzipID1
|
||||
z.buf[1] = gzipID2
|
||||
z.buf[2] = gzipDeflate
|
||||
z.buf[3] = 0
|
||||
if z.Extra != nil {
|
||||
z.buf[3] |= 0x04
|
||||
}
|
||||
if z.Name != "" {
|
||||
z.buf[3] |= 0x08
|
||||
}
|
||||
if z.Comment != "" {
|
||||
z.buf[3] |= 0x10
|
||||
}
|
||||
put4(z.buf[4:8], uint32(z.ModTime.Unix()))
|
||||
if z.level == BestCompression {
|
||||
z.buf[8] = 2
|
||||
} else if z.level == BestSpeed {
|
||||
z.buf[8] = 4
|
||||
} else {
|
||||
z.buf[8] = 0
|
||||
}
|
||||
z.buf[9] = z.OS
|
||||
var n int
|
||||
var err error
|
||||
n, err = z.w.Write(z.buf[0:10])
|
||||
if err != nil {
|
||||
z.pushError(err)
|
||||
return n, err
|
||||
}
|
||||
if z.Extra != nil {
|
||||
err = z.writeBytes(z.Extra)
|
||||
if err != nil {
|
||||
z.pushError(err)
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
if z.Name != "" {
|
||||
err = z.writeString(z.Name)
|
||||
if err != nil {
|
||||
z.pushError(err)
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
if z.Comment != "" {
|
||||
err = z.writeString(z.Comment)
|
||||
if err != nil {
|
||||
z.pushError(err)
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
// Start receiving data from compressors
|
||||
go func() {
|
||||
listen := z.results
|
||||
for {
|
||||
r, ok := <-listen
|
||||
// If closed, we are finished.
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
buf := <-r.result
|
||||
n, err := z.w.Write(buf)
|
||||
if err != nil {
|
||||
z.pushError(err)
|
||||
close(r.notifyWritten)
|
||||
return
|
||||
}
|
||||
if n != len(buf) {
|
||||
z.pushError(fmt.Errorf("gzip: short write %d should be %d", n, len(buf)))
|
||||
close(r.notifyWritten)
|
||||
return
|
||||
}
|
||||
z.dstPool.Put(buf)
|
||||
close(r.notifyWritten)
|
||||
}
|
||||
}()
|
||||
z.currentBuffer = make([]byte, 0, z.blockSize)
|
||||
}
|
||||
q := p
|
||||
for len(q) > 0 {
|
||||
length := len(q)
|
||||
if length+len(z.currentBuffer) > z.blockSize {
|
||||
length = z.blockSize - len(z.currentBuffer)
|
||||
}
|
||||
z.digest.Write(q[:length])
|
||||
z.currentBuffer = append(z.currentBuffer, q[:length]...)
|
||||
if len(z.currentBuffer) >= z.blockSize {
|
||||
z.compressCurrent(false)
|
||||
if err := z.checkError(); err != nil {
|
||||
return len(p) - len(q) - length, err
|
||||
}
|
||||
}
|
||||
z.size += length
|
||||
q = q[length:]
|
||||
}
|
||||
return len(p), z.checkError()
|
||||
}
|
||||
|
||||
// compressBlock compresses p into a buffer taken from z.dstPool and sends
// the compressed bytes on r.result. prevTail is handed to the compressor's
// ResetDict as the dictionary. When closed is true the compressor is also
// closed after being flushed.
//
// Step 1: compresses buffer to buffer
// Step 2: send writer to channel
// Step 3: Close result channel to indicate we are done
//
// If Flush or Close on the compressor fails, the error is passed to
// z.pushError and r.result is closed without a value being sent.
func (z *Writer) compressBlock(p, prevTail []byte, r result, closed bool) {
	// Runs on every exit path: the receiver of r.result is released and the
	// WaitGroup count added by the caller is balanced.
	defer func() {
		close(r.result)
		z.wg.Done()
	}()
	buf := z.dstPool.Get().([]byte)
	dest := bytes.NewBuffer(buf[:0])

	compressor := z.dictFlatePool.Get().(*flate.Writer)
	compressor.ResetDict(dest, prevTail)
	// NOTE(review): the error returned by Write is discarded; presumably any
	// failure resurfaces from Flush below — confirm.
	compressor.Write(p)

	err := compressor.Flush()
	if err != nil {
		z.pushError(err)
		return
	}
	if closed {
		err = compressor.Close()
		if err != nil {
			z.pushError(err)
			return
		}
	}
	// Only reached on success; on the error paths above the compressor is
	// not returned to the pool.
	z.dictFlatePool.Put(compressor)
	// Read back buffer
	buf = dest.Bytes()
	r.result <- buf
}
|
||||
|
||||
// Flush flushes any pending compressed data to the underlying writer.
|
||||
//
|
||||
// It is useful mainly in compressed network protocols, to ensure that
|
||||
// a remote reader has enough data to reconstruct a packet. Flush does
|
||||
// not return until the data has been written. If the underlying
|
||||
// writer returns an error, Flush returns that error.
|
||||
//
|
||||
// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH.
|
||||
func (z *Writer) Flush() error {
|
||||
if err := z.checkError(); err != nil {
|
||||
return err
|
||||
}
|
||||
if z.closed {
|
||||
return nil
|
||||
}
|
||||
if !z.wroteHeader {
|
||||
_, err := z.Write(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// We send current block to compression
|
||||
z.compressCurrent(true)
|
||||
|
||||
return z.checkError()
|
||||
}
|
||||
|
||||
// UncompressedSize will return the number of bytes written.
|
||||
// pgzip only, not a function in the official gzip package.
|
||||
func (z *Writer) UncompressedSize() int {
|
||||
return z.size
|
||||
}
|
||||
|
||||
// Close closes the Writer, flushing any unwritten data to the underlying
// io.Writer, but does not close the underlying io.Writer.
//
// If an error has already been recorded it is returned and nothing else is
// done. Calling Close on a Writer that is already closed returns nil.
func (z *Writer) Close() error {
	if err := z.checkError(); err != nil {
		return err
	}
	if z.closed {
		return nil
	}

	// Set before compressCurrent runs: it forwards z.closed to compressBlock,
	// which then closes the compressor after flushing the final block.
	z.closed = true
	if !z.wroteHeader {
		// Nothing has been written yet; an empty Write emits the header and
		// starts the goroutine that drains z.results.
		z.Write(nil)
		if err := z.checkError(); err != nil {
			return err
		}
	}
	// Compress whatever is still buffered and wait until it has been written.
	z.compressCurrent(true)
	if err := z.checkError(); err != nil {
		return err
	}
	// Closing results lets the goroutine started in Write return.
	close(z.results)
	// Trailer: the 32-bit sum from z.digest followed by the uncompressed
	// size truncated to 32 bits, both little-endian.
	put4(z.buf[0:4], z.digest.Sum32())
	put4(z.buf[4:8], uint32(z.size))
	_, err := z.w.Write(z.buf[0:8])
	if err != nil {
		z.pushError(err)
		return err
	}
	return nil
}
|
|
@ -0,0 +1,9 @@
|
|||
This project is originally a fork of [https://github.com/youtube/vitess](https://github.com/youtube/vitess)
|
||||
Copyright Google Inc
|
||||
|
||||
# Contributors
|
||||
Wenbin Xiao 2015
|
||||
Started this project and maintained it.
|
||||
|
||||
Andrew Brampton 2017
|
||||
Merged in multiple upstream fixes/changes.
|
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2017 Google Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
MAKEFLAGS = -s
|
||||
|
||||
sql.go: sql.y
|
||||
goyacc -o sql.go sql.y
|
||||
gofmt -w sql.go
|
||||
|
||||
clean:
|
||||
rm -f y.output sql.go
|
|
@ -0,0 +1,150 @@
|
|||
# sqlparser [![Build Status](https://img.shields.io/travis/xwb1989/sqlparser.svg)](https://travis-ci.org/xwb1989/sqlparser) [![Coverage](https://img.shields.io/coveralls/xwb1989/sqlparser.svg)](https://coveralls.io/github/xwb1989/sqlparser) [![Report card](https://goreportcard.com/badge/github.com/xwb1989/sqlparser)](https://goreportcard.com/report/github.com/xwb1989/sqlparser) [![GoDoc](https://godoc.org/github.com/xwb1989/sqlparser?status.svg)](https://godoc.org/github.com/xwb1989/sqlparser)
|
||||
|
||||
Go package for parsing MySQL SQL queries.
|
||||
|
||||
## Notice
|
||||
|
||||
The backbone of this repo is extracted from [vitessio/vitess](https://github.com/vitessio/vitess).
|
||||
|
||||
Inside vitessio/vitess there is a very nicely written SQL parser. However, as it's not a self-contained application, I created this one.
|
||||
It applies the same LICENSE as vitessio/vitess.
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/xwb1989/sqlparser"
|
||||
)
|
||||
```
|
||||
|
||||
Then use:
|
||||
|
||||
```go
|
||||
sql := "SELECT * FROM table WHERE a = 'abc'"
|
||||
stmt, err := sqlparser.Parse(sql)
|
||||
if err != nil {
|
||||
// Do something with the err
|
||||
}
|
||||
|
||||
// Otherwise do something with stmt
|
||||
switch stmt := stmt.(type) {
|
||||
case *sqlparser.Select:
|
||||
_ = stmt
|
||||
case *sqlparser.Insert:
|
||||
}
|
||||
```
|
||||
|
||||
Alternatively, to read many queries from an io.Reader:
|
||||
|
||||
```go
|
||||
r := strings.NewReader("INSERT INTO table1 VALUES (1, 'a'); INSERT INTO table2 VALUES (3, 4);")
|
||||
|
||||
tokens := sqlparser.NewTokenizer(r)
|
||||
for {
|
||||
stmt, err := sqlparser.ParseNext(tokens)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
// Do something with stmt or err.
|
||||
}
|
||||
```
|
||||
|
||||
See [parse_test.go](https://github.com/xwb1989/sqlparser/blob/master/parse_test.go) for more examples, or read the [godoc](https://godoc.org/github.com/xwb1989/sqlparser).
|
||||
|
||||
|
||||
## Porting Instructions
|
||||
|
||||
You only need the instructions below if you plan to keep this library up to date with [vitessio/vitess](https://github.com/vitessio/vitess).
|
||||
|
||||
### Keeping up to date
|
||||
|
||||
```bash
|
||||
shopt -s nullglob
|
||||
VITESS=${GOPATH?}/src/vitess.io/vitess/go/
|
||||
XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/
|
||||
|
||||
# Create patches for everything that changed
|
||||
LASTIMPORT=1b7879cb91f1dfe1a2dfa06fea96e951e3a7aec5
|
||||
for path in ${VITESS?}/{vt/sqlparser,sqltypes,bytes2,hack}; do
|
||||
cd ${path}
|
||||
git format-patch ${LASTIMPORT?} .
|
||||
done;
|
||||
|
||||
# Apply patches to the dependencies
|
||||
cd ${XWB1989?}
|
||||
git am --directory dependency -p2 ${VITESS?}/{sqltypes,bytes2,hack}/*.patch
|
||||
|
||||
# Apply the main patches to the repo
|
||||
cd ${XWB1989?}
|
||||
git am -p4 ${VITESS?}/vt/sqlparser/*.patch
|
||||
|
||||
# If you encounter diff failures, manually fix them with
|
||||
patch -p4 < .git/rebase-apply/patch
|
||||
...
|
||||
git add name_of_files
|
||||
git am --continue
|
||||
|
||||
# Cleanup
|
||||
rm ${VITESS?}/{sqltypes,bytes2,hack}/*.patch ${VITESS?}/*.patch
|
||||
|
||||
# and Finally update the LASTIMPORT in this README.
|
||||
```
|
||||
|
||||
### Fresh install
|
||||
|
||||
TODO: Change these instructions to use git to copy the files; that will make later patching easier.
|
||||
|
||||
```bash
|
||||
VITESS=${GOPATH?}/src/vitess.io/vitess/go/
|
||||
XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/
|
||||
|
||||
cd ${XWB1989?}
|
||||
|
||||
# Copy all the code
|
||||
cp -pr ${VITESS?}/vt/sqlparser/ .
|
||||
cp -pr ${VITESS?}/sqltypes dependency
|
||||
cp -pr ${VITESS?}/bytes2 dependency
|
||||
cp -pr ${VITESS?}/hack dependency
|
||||
|
||||
# Delete some code we haven't ported
|
||||
rm dependency/sqltypes/arithmetic.go dependency/sqltypes/arithmetic_test.go dependency/sqltypes/event_token.go dependency/sqltypes/event_token_test.go dependency/sqltypes/proto3.go dependency/sqltypes/proto3_test.go dependency/sqltypes/query_response.go dependency/sqltypes/result.go dependency/sqltypes/result_test.go
|
||||
|
||||
# Some automated fixes
|
||||
|
||||
# Fix imports
|
||||
sed -i '.bak' 's_vitess.io/vitess/go/vt/proto/query_github.com/xwb1989/sqlparser/dependency/querypb_g' *.go dependency/sqltypes/*.go
|
||||
sed -i '.bak' 's_vitess.io/vitess/go/_github.com/xwb1989/sqlparser/dependency/_g' *.go dependency/sqltypes/*.go
|
||||
|
||||
# Copy the proto, but basically drop everything we don't want
|
||||
cp -pr ${VITESS?}/vt/proto/query dependency/querypb
|
||||
|
||||
sed -i '.bak' 's_.*Descriptor.*__g' dependency/querypb/*.go
|
||||
sed -i '.bak' 's_.*ProtoMessage.*__g' dependency/querypb/*.go
|
||||
|
||||
sed -i '.bak' 's/proto.CompactTextString(m)/"TODO"/g' dependency/querypb/*.go
|
||||
sed -i '.bak' 's/proto.EnumName/EnumName/g' dependency/querypb/*.go
|
||||
|
||||
sed -i '.bak' 's/proto.Equal/reflect.DeepEqual/g' dependency/sqltypes/*.go
|
||||
|
||||
# Remove the error library
|
||||
sed -i '.bak' 's/vterrors.Errorf([^,]*, /fmt.Errorf(/g' *.go dependency/sqltypes/*.go
|
||||
sed -i '.bak' 's/vterrors.New([^,]*, /errors.New(/g' *.go dependency/sqltypes/*.go
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
VITESS=${GOPATH?}/src/vitess.io/vitess/go/
|
||||
XWB1989=${GOPATH?}/src/github.com/xwb1989/sqlparser/
|
||||
|
||||
cd ${XWB1989?}
|
||||
|
||||
# Test, fix and repeat
|
||||
go test ./...
|
||||
|
||||
# Finally make some diffs (for later reference)
|
||||
diff -u ${VITESS?}/sqltypes/ ${XWB1989?}/dependency/sqltypes/ > ${XWB1989?}/patches/sqltypes.patch
|
||||
diff -u ${VITESS?}/bytes2/ ${XWB1989?}/dependency/bytes2/ > ${XWB1989?}/patches/bytes2.patch
|
||||
diff -u ${VITESS?}/vt/proto/query/ ${XWB1989?}/dependency/querypb/ > ${XWB1989?}/patches/querypb.patch
|
||||
diff -u ${VITESS?}/vt/sqlparser/ ${XWB1989?}/ > ${XWB1989?}/patches/sqlparser.patch
|
||||
```
|
|
@ -0,0 +1,343 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
// analyzer.go contains utility analysis functions.
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/sqltypes"
|
||||
)
|
||||
|
||||
// These constants are used to identify the SQL statement type.
// They are consecutive integers assigned by iota, starting at zero with
// StmtSelect. Preview returns one of them and StmtType turns one into a
// display string.
const (
	StmtSelect = iota
	StmtStream
	StmtInsert
	StmtReplace
	StmtUpdate
	StmtDelete
	StmtDDL
	StmtBegin
	StmtCommit
	StmtRollback
	StmtSet
	StmtShow
	StmtUse
	StmtOther
	// StmtUnknown is what Preview returns when nothing else matches.
	StmtUnknown
	// StmtComment is what Preview returns for otherwise unrecognized input
	// that begins with "/*!".
	StmtComment
)
|
||||
|
||||
// Preview analyzes the beginning of the query using a simpler and faster
|
||||
// textual comparison to identify the statement type.
|
||||
func Preview(sql string) int {
|
||||
trimmed := StripLeadingComments(sql)
|
||||
|
||||
firstWord := trimmed
|
||||
if end := strings.IndexFunc(trimmed, unicode.IsSpace); end != -1 {
|
||||
firstWord = trimmed[:end]
|
||||
}
|
||||
firstWord = strings.TrimLeftFunc(firstWord, func(r rune) bool { return !unicode.IsLetter(r) })
|
||||
// Comparison is done in order of priority.
|
||||
loweredFirstWord := strings.ToLower(firstWord)
|
||||
switch loweredFirstWord {
|
||||
case "select":
|
||||
return StmtSelect
|
||||
case "stream":
|
||||
return StmtStream
|
||||
case "insert":
|
||||
return StmtInsert
|
||||
case "replace":
|
||||
return StmtReplace
|
||||
case "update":
|
||||
return StmtUpdate
|
||||
case "delete":
|
||||
return StmtDelete
|
||||
}
|
||||
// For the following statements it is not sufficient to rely
|
||||
// on loweredFirstWord. This is because they are not statements
|
||||
// in the grammar and we are relying on Preview to parse them.
|
||||
// For instance, we don't want: "BEGIN JUNK" to be parsed
|
||||
// as StmtBegin.
|
||||
trimmedNoComments, _ := SplitMarginComments(trimmed)
|
||||
switch strings.ToLower(trimmedNoComments) {
|
||||
case "begin", "start transaction":
|
||||
return StmtBegin
|
||||
case "commit":
|
||||
return StmtCommit
|
||||
case "rollback":
|
||||
return StmtRollback
|
||||
}
|
||||
switch loweredFirstWord {
|
||||
case "create", "alter", "rename", "drop", "truncate":
|
||||
return StmtDDL
|
||||
case "set":
|
||||
return StmtSet
|
||||
case "show":
|
||||
return StmtShow
|
||||
case "use":
|
||||
return StmtUse
|
||||
case "analyze", "describe", "desc", "explain", "repair", "optimize":
|
||||
return StmtOther
|
||||
}
|
||||
if strings.Index(trimmed, "/*!") == 0 {
|
||||
return StmtComment
|
||||
}
|
||||
return StmtUnknown
|
||||
}
|
||||
|
||||
// StmtType returns the statement type as a string
|
||||
func StmtType(stmtType int) string {
|
||||
switch stmtType {
|
||||
case StmtSelect:
|
||||
return "SELECT"
|
||||
case StmtStream:
|
||||
return "STREAM"
|
||||
case StmtInsert:
|
||||
return "INSERT"
|
||||
case StmtReplace:
|
||||
return "REPLACE"
|
||||
case StmtUpdate:
|
||||
return "UPDATE"
|
||||
case StmtDelete:
|
||||
return "DELETE"
|
||||
case StmtDDL:
|
||||
return "DDL"
|
||||
case StmtBegin:
|
||||
return "BEGIN"
|
||||
case StmtCommit:
|
||||
return "COMMIT"
|
||||
case StmtRollback:
|
||||
return "ROLLBACK"
|
||||
case StmtSet:
|
||||
return "SET"
|
||||
case StmtShow:
|
||||
return "SHOW"
|
||||
case StmtUse:
|
||||
return "USE"
|
||||
case StmtOther:
|
||||
return "OTHER"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// IsDML returns true if the query is an INSERT, UPDATE or DELETE statement.
|
||||
func IsDML(sql string) bool {
|
||||
switch Preview(sql) {
|
||||
case StmtInsert, StmtReplace, StmtUpdate, StmtDelete:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetTableName returns the table name from the SimpleTableExpr
|
||||
// only if it's a simple expression. Otherwise, it returns "".
|
||||
func GetTableName(node SimpleTableExpr) TableIdent {
|
||||
if n, ok := node.(TableName); ok && n.Qualifier.IsEmpty() {
|
||||
return n.Name
|
||||
}
|
||||
// sub-select or '.' expression
|
||||
return NewTableIdent("")
|
||||
}
|
||||
|
||||
// IsColName returns true if the Expr is a *ColName.
|
||||
func IsColName(node Expr) bool {
|
||||
_, ok := node.(*ColName)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsValue returns true if the Expr is a string, integral or value arg.
|
||||
// NULL is not considered to be a value.
|
||||
func IsValue(node Expr) bool {
|
||||
switch v := node.(type) {
|
||||
case *SQLVal:
|
||||
switch v.Type {
|
||||
case StrVal, HexVal, IntVal, ValArg:
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsNull returns true if the Expr is SQL NULL
|
||||
func IsNull(node Expr) bool {
|
||||
switch node.(type) {
|
||||
case *NullVal:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsSimpleTuple returns true if the Expr is a ValTuple that
|
||||
// contains simple values or if it's a list arg.
|
||||
func IsSimpleTuple(node Expr) bool {
|
||||
switch vals := node.(type) {
|
||||
case ValTuple:
|
||||
for _, n := range vals {
|
||||
if !IsValue(n) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case ListArg:
|
||||
return true
|
||||
}
|
||||
// It's a subquery
|
||||
return false
|
||||
}
|
||||
|
||||
// NewPlanValue builds a sqltypes.PlanValue from an Expr.
|
||||
func NewPlanValue(node Expr) (sqltypes.PlanValue, error) {
|
||||
switch node := node.(type) {
|
||||
case *SQLVal:
|
||||
switch node.Type {
|
||||
case ValArg:
|
||||
return sqltypes.PlanValue{Key: string(node.Val[1:])}, nil
|
||||
case IntVal:
|
||||
n, err := sqltypes.NewIntegral(string(node.Val))
|
||||
if err != nil {
|
||||
return sqltypes.PlanValue{}, fmt.Errorf("%v", err)
|
||||
}
|
||||
return sqltypes.PlanValue{Value: n}, nil
|
||||
case StrVal:
|
||||
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, node.Val)}, nil
|
||||
case HexVal:
|
||||
v, err := node.HexDecode()
|
||||
if err != nil {
|
||||
return sqltypes.PlanValue{}, fmt.Errorf("%v", err)
|
||||
}
|
||||
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, v)}, nil
|
||||
}
|
||||
case ListArg:
|
||||
return sqltypes.PlanValue{ListKey: string(node[2:])}, nil
|
||||
case ValTuple:
|
||||
pv := sqltypes.PlanValue{
|
||||
Values: make([]sqltypes.PlanValue, 0, len(node)),
|
||||
}
|
||||
for _, val := range node {
|
||||
innerpv, err := NewPlanValue(val)
|
||||
if err != nil {
|
||||
return sqltypes.PlanValue{}, err
|
||||
}
|
||||
if innerpv.ListKey != "" || innerpv.Values != nil {
|
||||
return sqltypes.PlanValue{}, errors.New("unsupported: nested lists")
|
||||
}
|
||||
pv.Values = append(pv.Values, innerpv)
|
||||
}
|
||||
return pv, nil
|
||||
case *NullVal:
|
||||
return sqltypes.PlanValue{}, nil
|
||||
}
|
||||
return sqltypes.PlanValue{}, fmt.Errorf("expression is too complex '%v'", String(node))
|
||||
}
|
||||
|
||||
// StringIn reports whether str is exactly equal to at least one of values.
// With no values it returns false.
func StringIn(str string, values ...string) bool {
	for i := 0; i < len(values); i++ {
		if values[i] == str {
			return true
		}
	}
	return false
}
|
||||
|
||||
// SetKey identifies one variable assigned by a SET statement, as
// extracted from a single SetExpr.
type SetKey struct {
	// Key is the variable name.
	Key string
	// Scope is the scope the variable is assigned in.
	Scope string
}
|
||||
|
||||
// ExtractSetValues returns a map of key-value pairs
|
||||
// if the query is a SET statement. Values can be bool, int64 or string.
|
||||
// Since set variable names are case insensitive, all keys are returned
|
||||
// as lower case.
|
||||
func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope string, err error) {
|
||||
stmt, err := Parse(sql)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
setStmt, ok := stmt.(*Set)
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("ast did not yield *sqlparser.Set: %T", stmt)
|
||||
}
|
||||
result := make(map[SetKey]interface{})
|
||||
for _, expr := range setStmt.Exprs {
|
||||
scope := SessionStr
|
||||
key := expr.Name.Lowered()
|
||||
switch {
|
||||
case strings.HasPrefix(key, "@@global."):
|
||||
scope = GlobalStr
|
||||
key = strings.TrimPrefix(key, "@@global.")
|
||||
case strings.HasPrefix(key, "@@session."):
|
||||
key = strings.TrimPrefix(key, "@@session.")
|
||||
case strings.HasPrefix(key, "@@"):
|
||||
key = strings.TrimPrefix(key, "@@")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(expr.Name.Lowered(), "@@") {
|
||||
if setStmt.Scope != "" && scope != "" {
|
||||
return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope")
|
||||
}
|
||||
_, out := NewStringTokenizer(key).Scan()
|
||||
key = string(out)
|
||||
}
|
||||
|
||||
setKey := SetKey{
|
||||
Key: key,
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
switch expr := expr.Expr.(type) {
|
||||
case *SQLVal:
|
||||
switch expr.Type {
|
||||
case StrVal:
|
||||
result[setKey] = strings.ToLower(string(expr.Val))
|
||||
case IntVal:
|
||||
num, err := strconv.ParseInt(string(expr.Val), 0, 64)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
result[setKey] = num
|
||||
default:
|
||||
return nil, "", fmt.Errorf("invalid value type: %v", String(expr))
|
||||
}
|
||||
case BoolVal:
|
||||
var val int64
|
||||
if expr {
|
||||
val = 1
|
||||
}
|
||||
result[setKey] = val
|
||||
case *ColName:
|
||||
result[setKey] = expr.Name.String()
|
||||
case *NullVal:
|
||||
result[setKey] = nil
|
||||
case *Default:
|
||||
result[setKey] = "default"
|
||||
default:
|
||||
return nil, "", fmt.Errorf("invalid syntax: %s", String(expr))
|
||||
}
|
||||
}
|
||||
return result, strings.ToLower(setStmt.Scope), nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,293 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
const (
|
||||
// DirectiveMultiShardAutocommit is the query comment directive to allow
|
||||
// single round trip autocommit with a multi-shard statement.
|
||||
DirectiveMultiShardAutocommit = "MULTI_SHARD_AUTOCOMMIT"
|
||||
// DirectiveSkipQueryPlanCache skips query plan cache when set.
|
||||
DirectiveSkipQueryPlanCache = "SKIP_QUERY_PLAN_CACHE"
|
||||
// DirectiveQueryTimeout sets a query timeout in vtgate. Only supported for SELECTS.
|
||||
DirectiveQueryTimeout = "QUERY_TIMEOUT_MS"
|
||||
)
|
||||
|
||||
// isNonSpace reports whether r is anything other than Unicode whitespace.
func isNonSpace(r rune) bool {
	return !unicode.IsSpace(r)
}

// leadingCommentEnd returns the byte offset just past the run of
// /* ... */ comments at the start of text, or 0 when text does not start
// with a complete comment. Whitespace before and between the comments,
// and whitespace directly after the last one, counts as part of the run.
func leadingCommentEnd(text string) (end int) {
	found := false
	offset := 0
	for offset < len(text) {
		// Skip whitespace up to the next visible character.
		skip := strings.IndexFunc(text[offset:], isNonSpace)
		if skip < 0 {
			break
		}
		offset += skip
		rest := text[offset:]

		// A comment needs at least "/**/" and must begin with "/*".
		if len(rest) < 4 || !strings.HasPrefix(rest, "/*") {
			break
		}
		closeAt := strings.Index(rest[2:], "*/")
		if closeAt < 0 {
			// Unterminated comment: stop here.
			break
		}

		found = true
		// Opening and closing markers (2 bytes each) plus the body.
		offset += closeAt + 4
	}

	if !found {
		return 0
	}
	return offset
}
|
||||
|
||||
// trailingCommentStart returns the byte offset at which the run of
// trailing /* ... */ comments begins. Whitespace between the query and the
// comments counts as part of the run. If text does not end with a complete
// comment, len(text) is returned.
//
// The returned offset always falls on a rune boundary, so a query ending
// in a multi-byte character is never split.
func trailingCommentStart(text string) (start int) {
	hasComment := false
	reducedLen := len(text)
	for reducedLen > 0 {
		// Drop trailing whitespace. Using the length of the trimmed prefix
		// (instead of LastIndexFunc+1) places the boundary after the whole
		// last rune even when that rune is encoded in several bytes.
		nextReducedLen := len(strings.TrimRightFunc(text[:reducedLen], unicode.IsSpace))
		if nextReducedLen == 0 {
			break
		}
		reducedLen = nextReducedLen

		// A comment needs at least "/**/" and must end with "*/".
		if reducedLen < 4 || !strings.HasSuffix(text[:reducedLen], "*/") {
			break
		}

		// Find the beginning of the comment.
		startCommentPos := strings.LastIndex(text[:reducedLen-2], "/*")
		if startCommentPos < 0 {
			// Closing marker without an opening one: badly formed SQL.
			break
		}

		hasComment = true
		reducedLen = startCommentPos
	}

	if hasComment {
		return reducedLen
	}
	return len(text)
}
|
||||
|
||||
// MarginComments carries the comment text found on either side of a query.
type MarginComments struct {
	// Leading is the comment text preceding the query.
	Leading string
	// Trailing is the comment text following the query.
	Trailing string
}
|
||||
|
||||
// SplitMarginComments pulls out any leading or trailing comments from a raw sql query.
|
||||
// This function also trims leading (if there's a comment) and trailing whitespace.
|
||||
func SplitMarginComments(sql string) (query string, comments MarginComments) {
|
||||
trailingStart := trailingCommentStart(sql)
|
||||
leadingEnd := leadingCommentEnd(sql[:trailingStart])
|
||||
comments = MarginComments{
|
||||
Leading: strings.TrimLeftFunc(sql[:leadingEnd], unicode.IsSpace),
|
||||
Trailing: strings.TrimRightFunc(sql[trailingStart:], unicode.IsSpace),
|
||||
}
|
||||
return strings.TrimFunc(sql[leadingEnd:trailingStart], unicode.IsSpace), comments
|
||||
}
|
||||
|
||||
// StripLeadingComments trims whitespace from sql and removes every comment
// at its start, both /* ... */ and "-- ..." forms. Stripping stops, and the
// remaining text is returned as is, when it reaches an unterminated
// comment or a /*! ... */ comment.
func StripLeadingComments(sql string) string {
	sql = strings.TrimFunc(sql, unicode.IsSpace)

	for hasCommentPrefix(sql) {
		if sql[0] == '/' {
			// Block comment.
			end := strings.Index(sql, "*/")
			if end <= 1 {
				return sql
			}
			// Keep /*! ... */ and /*!50700 ... */ comments.
			if strings.HasPrefix(sql, "/*!") {
				return sql
			}
			sql = sql[end+2:]
		} else {
			// Line comment: runs to the end of the line.
			newline := strings.IndexByte(sql, '\n')
			if newline < 0 {
				return sql
			}
			sql = sql[newline+1:]
		}

		sql = strings.TrimFunc(sql, unicode.IsSpace)
	}

	return sql
}

// hasCommentPrefix reports whether sql begins with "/*" or "--".
func hasCommentPrefix(sql string) bool {
	return strings.HasPrefix(sql, "/*") || strings.HasPrefix(sql, "--")
}
|
||||
|
||||
// ExtractMysqlComment splits a comment-only query such as
//
//	/*!50708 sql here */
//
// into its version prefix ("50708") and the enclosed SQL ("sql here").
// The version is the run of up to five leading digits after "/*!" and may
// be empty; the inner SQL has surrounding whitespace trimmed.
//
// The first three and last two bytes are dropped without being inspected,
// so the caller is expected to pass a /*! ... */ comment. Input too short
// to hold both markers yields two empty strings instead of panicking.
func ExtractMysqlComment(sql string) (version string, innerSQL string) {
	const prefixLen, suffixLen = len("/*!"), len("*/")
	if len(sql) < prefixLen+suffixLen {
		return "", ""
	}
	sql = sql[prefixLen : len(sql)-suffixLen]

	// The version ends at the first non-digit or after five digits. If the
	// body is nothing but (at most five) digits, all of it is the version.
	versionEnd := len(sql)
	digits := 0
	for i, r := range sql {
		if digits == 5 || !unicode.IsDigit(r) {
			versionEnd = i
			break
		}
		digits++
	}

	version = sql[:versionEnd]
	innerSQL = strings.TrimFunc(sql[versionEnd:], unicode.IsSpace)
	return version, innerSQL
}
|
||||
|
||||
// commentDirectivePreamble is the prefix a comment must start with for
// ExtractCommentDirectives to parse it for directives.
const commentDirectivePreamble = "/*vt+"
|
||||
|
||||
// CommentDirectives is the parsed representation for execution directives
|
||||
// conveyed in query comments
|
||||
// Keys are directive names. ExtractCommentDirectives stores each value as
// an int, a bool or a string.
type CommentDirectives map[string]interface{}
|
||||
|
||||
// ExtractCommentDirectives parses the comment list for any execution directives
|
||||
// of the form:
|
||||
//
|
||||
// /*vt+ OPTION_ONE=1 OPTION_TWO OPTION_THREE=abcd */
|
||||
//
|
||||
// It returns the map of the directive values or nil if there aren't any.
|
||||
func ExtractCommentDirectives(comments Comments) CommentDirectives {
|
||||
if comments == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var vals map[string]interface{}
|
||||
|
||||
for _, comment := range comments {
|
||||
commentStr := string(comment)
|
||||
if commentStr[0:5] != commentDirectivePreamble {
|
||||
continue
|
||||
}
|
||||
|
||||
if vals == nil {
|
||||
vals = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// Split on whitespace and ignore the first and last directive
|
||||
// since they contain the comment start/end
|
||||
directives := strings.Fields(commentStr)
|
||||
for i := 1; i < len(directives)-1; i++ {
|
||||
directive := directives[i]
|
||||
sep := strings.IndexByte(directive, '=')
|
||||
|
||||
// No value is equivalent to a true boolean
|
||||
if sep == -1 {
|
||||
vals[directive] = true
|
||||
continue
|
||||
}
|
||||
|
||||
strVal := directive[sep+1:]
|
||||
directive = directive[:sep]
|
||||
|
||||
intVal, err := strconv.Atoi(strVal)
|
||||
if err == nil {
|
||||
vals[directive] = intVal
|
||||
continue
|
||||
}
|
||||
|
||||
boolVal, err := strconv.ParseBool(strVal)
|
||||
if err == nil {
|
||||
vals[directive] = boolVal
|
||||
continue
|
||||
}
|
||||
|
||||
vals[directive] = strVal
|
||||
}
|
||||
}
|
||||
return vals
|
||||
}
|
||||
|
||||
// IsSet checks the directive map for the named directive and returns
|
||||
// true if the directive is set and has a true/false or 0/1 value
|
||||
func (d CommentDirectives) IsSet(key string) bool {
|
||||
if d == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
val, ok := d[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
boolVal, ok := val.(bool)
|
||||
if ok {
|
||||
return boolVal
|
||||
}
|
||||
|
||||
intVal, ok := val.(int)
|
||||
if ok {
|
||||
return intVal == 1
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SkipQueryPlanCacheDirective returns true if skip query plan cache directive is set to true in query.
|
||||
func SkipQueryPlanCacheDirective(stmt Statement) bool {
|
||||
switch stmt := stmt.(type) {
|
||||
case *Select:
|
||||
directives := ExtractCommentDirectives(stmt.Comments)
|
||||
if directives.IsSet(DirectiveSkipQueryPlanCache) {
|
||||
return true
|
||||
}
|
||||
case *Insert:
|
||||
directives := ExtractCommentDirectives(stmt.Comments)
|
||||
if directives.IsSet(DirectiveSkipQueryPlanCache) {
|
||||
return true
|
||||
}
|
||||
case *Update:
|
||||
directives := ExtractCommentDirectives(stmt.Comments)
|
||||
if directives.IsSet(DirectiveSkipQueryPlanCache) {
|
||||
return true
|
||||
}
|
||||
case *Delete:
|
||||
directives := ExtractCommentDirectives(stmt.Comments)
|
||||
if directives.IsSet(DirectiveSkipQueryPlanCache) {
|
||||
return true
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package bytes2
|
||||
|
||||
// Buffer is a lightweight, append-only stand-in for the write side of
// bytes.Buffer, intended for very high QPS paths (WriteByte in
// particular) where going through a Writer abstraction is too costly.
// The methods keep error results for signature compatibility but always
// return a nil error.
type Buffer struct {
	bytes []byte
}

// NewBuffer is equivalent to bytes.NewBuffer: the buffer takes b as its
// initial contents without copying it.
func NewBuffer(b []byte) *Buffer {
	return &Buffer{bytes: b}
}

// Write is equivalent to bytes.Buffer.Write: it appends b and returns
// len(b) and a nil error.
func (bb *Buffer) Write(b []byte) (int, error) {
	n := len(b)
	bb.bytes = append(bb.bytes, b...)
	return n, nil
}

// WriteString is equivalent to bytes.Buffer.WriteString: it appends s and
// returns len(s) and a nil error.
func (bb *Buffer) WriteString(s string) (int, error) {
	n := len(s)
	bb.bytes = append(bb.bytes, s...)
	return n, nil
}

// WriteByte is equivalent to bytes.Buffer.WriteByte: it appends b and
// returns a nil error.
func (bb *Buffer) WriteByte(b byte) error {
	bb.bytes = append(bb.bytes, b)
	return nil
}

// Bytes is equivalent to bytes.Buffer.Bytes: it returns the accumulated
// contents without copying.
func (bb *Buffer) Bytes() []byte {
	return bb.bytes
}

// String is equivalent to bytes.Buffer.String: it returns a copy of the
// accumulated contents as a string.
func (bb *Buffer) String() string {
	return string(bb.bytes)
}

// Len is equivalent to bytes.Buffer.Len: it returns the number of bytes
// accumulated so far.
func (bb *Buffer) Len() int {
	return len(bb.bytes)
}
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package hack gives you some efficient functionality at the cost of
|
||||
// breaking some Go rules.
|
||||
package hack
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// StringArena groups the allocations of many strings with a similar
// lifetime into one pre-sized buffer.
type StringArena struct {
	// buf is the arena storage; its length is the number of bytes in use.
	buf []byte
	// str is a string header over the same storage as buf, spanning its
	// full capacity, so arena strings can be sliced out of it.
	str string
}

// NewStringArena creates an arena with room for size bytes.
func NewStringArena(size int) *StringArena {
	arena := &StringArena{buf: make([]byte, 0, size)}
	// Point str at buf's backing array and make it span the whole capacity.
	sliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&arena.buf))
	strHdr := (*reflect.StringHeader)(unsafe.Pointer(&arena.str))
	strHdr.Data = sliceHdr.Data
	strHdr.Len = sliceHdr.Cap
	return arena
}

// NewString copies b into the arena and returns the copy as a string that
// shares the arena's storage. An empty b yields "". If b does not fit in
// the remaining space, an ordinary Go string is returned and the arena is
// left untouched.
func (sa *StringArena) NewString(b []byte) string {
	n := len(b)
	if n == 0 {
		return ""
	}
	used := len(sa.buf)
	if used+n > cap(sa.buf) {
		return string(b)
	}
	// Capacity was checked above, so append cannot reallocate and str
	// stays valid.
	sa.buf = append(sa.buf, b...)
	return sa.str[used : used+n]
}

// SpaceLeft returns the number of unused bytes remaining in the arena.
func (sa *StringArena) SpaceLeft() int {
	return cap(sa.buf) - len(sa.buf)
}
|
||||
|
||||
// String reinterprets b as a string without copying: the result shares
// b's backing array, so later writes to b show through the string.
// An empty b yields "".
// USE AT YOUR OWN RISK
func String(b []byte) (s string) {
	if len(b) != 0 {
		src := (*reflect.SliceHeader)(unsafe.Pointer(&b))
		dst := (*reflect.StringHeader)(unsafe.Pointer(&s))
		dst.Data = src.Data
		dst.Len = src.Len
	}
	return s
}
|
||||
|
||||
// StringPointer returns the address of the first byte of s, i.e. what
// &s[0] would be if Go allowed taking it.
func StringPointer(s string) unsafe.Pointer {
	return unsafe.Pointer((*reflect.StringHeader)(unsafe.Pointer(&s)).Data)
}
|
2734
vendor/github.com/xwb1989/sqlparser/dependency/querypb/query.pb.go
generated
vendored
Normal file
2734
vendor/github.com/xwb1989/sqlparser/dependency/querypb/query.pb.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
266
vendor/github.com/xwb1989/sqlparser/dependency/sqltypes/bind_variables.go
generated
vendored
Normal file
266
vendor/github.com/xwb1989/sqlparser/dependency/sqltypes/bind_variables.go
generated
vendored
Normal file
|
@ -0,0 +1,266 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/querypb"
|
||||
)
|
||||
|
||||
// NullBindVariable is a bindvar with NULL value.
|
||||
var NullBindVariable = &querypb.BindVariable{Type: querypb.Type_NULL_TYPE}
|
||||
|
||||
// ValueToProto converts Value to a *querypb.Value.
|
||||
func ValueToProto(v Value) *querypb.Value {
|
||||
return &querypb.Value{Type: v.typ, Value: v.val}
|
||||
}
|
||||
|
||||
// ProtoToValue converts a *querypb.Value to a Value.
|
||||
func ProtoToValue(v *querypb.Value) Value {
|
||||
return MakeTrusted(v.Type, v.Value)
|
||||
}
|
||||
|
||||
// BuildBindVariables builds a map[string]*querypb.BindVariable from a map[string]interface{}.
|
||||
func BuildBindVariables(in map[string]interface{}) (map[string]*querypb.BindVariable, error) {
|
||||
if len(in) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
out := make(map[string]*querypb.BindVariable, len(in))
|
||||
for k, v := range in {
|
||||
bv, err := BuildBindVariable(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %v", k, err)
|
||||
}
|
||||
out[k] = bv
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Int32BindVariable converts an int32 to a bind var.
|
||||
func Int32BindVariable(v int32) *querypb.BindVariable {
|
||||
return ValueBindVariable(NewInt32(v))
|
||||
}
|
||||
|
||||
// Int64BindVariable converts an int64 to a bind var.
|
||||
func Int64BindVariable(v int64) *querypb.BindVariable {
|
||||
return ValueBindVariable(NewInt64(v))
|
||||
}
|
||||
|
||||
// Uint64BindVariable converts a uint64 to a bind var.
|
||||
func Uint64BindVariable(v uint64) *querypb.BindVariable {
|
||||
return ValueBindVariable(NewUint64(v))
|
||||
}
|
||||
|
||||
// Float64BindVariable converts a float64 to a bind var.
|
||||
func Float64BindVariable(v float64) *querypb.BindVariable {
|
||||
return ValueBindVariable(NewFloat64(v))
|
||||
}
|
||||
|
||||
// StringBindVariable converts a string to a bind var.
|
||||
func StringBindVariable(v string) *querypb.BindVariable {
|
||||
return ValueBindVariable(NewVarChar(v))
|
||||
}
|
||||
|
||||
// BytesBindVariable converts a []byte to a bind var.
|
||||
func BytesBindVariable(v []byte) *querypb.BindVariable {
|
||||
return &querypb.BindVariable{Type: VarBinary, Value: v}
|
||||
}
|
||||
|
||||
// ValueBindVariable converts a Value to a bind var.
|
||||
func ValueBindVariable(v Value) *querypb.BindVariable {
|
||||
return &querypb.BindVariable{Type: v.typ, Value: v.val}
|
||||
}
|
||||
|
||||
// BuildBindVariable builds a *querypb.BindVariable from a valid input type.
|
||||
func BuildBindVariable(v interface{}) (*querypb.BindVariable, error) {
|
||||
switch v := v.(type) {
|
||||
case string:
|
||||
return StringBindVariable(v), nil
|
||||
case []byte:
|
||||
return BytesBindVariable(v), nil
|
||||
case int:
|
||||
return &querypb.BindVariable{
|
||||
Type: querypb.Type_INT64,
|
||||
Value: strconv.AppendInt(nil, int64(v), 10),
|
||||
}, nil
|
||||
case int64:
|
||||
return Int64BindVariable(v), nil
|
||||
case uint64:
|
||||
return Uint64BindVariable(v), nil
|
||||
case float64:
|
||||
return Float64BindVariable(v), nil
|
||||
case nil:
|
||||
return NullBindVariable, nil
|
||||
case Value:
|
||||
return ValueBindVariable(v), nil
|
||||
case *querypb.BindVariable:
|
||||
return v, nil
|
||||
case []interface{}:
|
||||
bv := &querypb.BindVariable{
|
||||
Type: querypb.Type_TUPLE,
|
||||
Values: make([]*querypb.Value, len(v)),
|
||||
}
|
||||
values := make([]querypb.Value, len(v))
|
||||
for i, lv := range v {
|
||||
lbv, err := BuildBindVariable(lv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
values[i].Type = lbv.Type
|
||||
values[i].Value = lbv.Value
|
||||
bv.Values[i] = &values[i]
|
||||
}
|
||||
return bv, nil
|
||||
case []string:
|
||||
bv := &querypb.BindVariable{
|
||||
Type: querypb.Type_TUPLE,
|
||||
Values: make([]*querypb.Value, len(v)),
|
||||
}
|
||||
values := make([]querypb.Value, len(v))
|
||||
for i, lv := range v {
|
||||
values[i].Type = querypb.Type_VARCHAR
|
||||
values[i].Value = []byte(lv)
|
||||
bv.Values[i] = &values[i]
|
||||
}
|
||||
return bv, nil
|
||||
case [][]byte:
|
||||
bv := &querypb.BindVariable{
|
||||
Type: querypb.Type_TUPLE,
|
||||
Values: make([]*querypb.Value, len(v)),
|
||||
}
|
||||
values := make([]querypb.Value, len(v))
|
||||
for i, lv := range v {
|
||||
values[i].Type = querypb.Type_VARBINARY
|
||||
values[i].Value = lv
|
||||
bv.Values[i] = &values[i]
|
||||
}
|
||||
return bv, nil
|
||||
case []int:
|
||||
bv := &querypb.BindVariable{
|
||||
Type: querypb.Type_TUPLE,
|
||||
Values: make([]*querypb.Value, len(v)),
|
||||
}
|
||||
values := make([]querypb.Value, len(v))
|
||||
for i, lv := range v {
|
||||
values[i].Type = querypb.Type_INT64
|
||||
values[i].Value = strconv.AppendInt(nil, int64(lv), 10)
|
||||
bv.Values[i] = &values[i]
|
||||
}
|
||||
return bv, nil
|
||||
case []int64:
|
||||
bv := &querypb.BindVariable{
|
||||
Type: querypb.Type_TUPLE,
|
||||
Values: make([]*querypb.Value, len(v)),
|
||||
}
|
||||
values := make([]querypb.Value, len(v))
|
||||
for i, lv := range v {
|
||||
values[i].Type = querypb.Type_INT64
|
||||
values[i].Value = strconv.AppendInt(nil, lv, 10)
|
||||
bv.Values[i] = &values[i]
|
||||
}
|
||||
return bv, nil
|
||||
case []uint64:
|
||||
bv := &querypb.BindVariable{
|
||||
Type: querypb.Type_TUPLE,
|
||||
Values: make([]*querypb.Value, len(v)),
|
||||
}
|
||||
values := make([]querypb.Value, len(v))
|
||||
for i, lv := range v {
|
||||
values[i].Type = querypb.Type_UINT64
|
||||
values[i].Value = strconv.AppendUint(nil, lv, 10)
|
||||
bv.Values[i] = &values[i]
|
||||
}
|
||||
return bv, nil
|
||||
case []float64:
|
||||
bv := &querypb.BindVariable{
|
||||
Type: querypb.Type_TUPLE,
|
||||
Values: make([]*querypb.Value, len(v)),
|
||||
}
|
||||
values := make([]querypb.Value, len(v))
|
||||
for i, lv := range v {
|
||||
values[i].Type = querypb.Type_FLOAT64
|
||||
values[i].Value = strconv.AppendFloat(nil, lv, 'g', -1, 64)
|
||||
bv.Values[i] = &values[i]
|
||||
}
|
||||
return bv, nil
|
||||
}
|
||||
return nil, fmt.Errorf("type %T not supported as bind var: %v", v, v)
|
||||
}
|
||||
|
||||
// ValidateBindVariables validates a map[string]*querypb.BindVariable.
|
||||
func ValidateBindVariables(bv map[string]*querypb.BindVariable) error {
|
||||
for k, v := range bv {
|
||||
if err := ValidateBindVariable(v); err != nil {
|
||||
return fmt.Errorf("%s: %v", k, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateBindVariable returns an error if the bind variable has inconsistent
|
||||
// fields.
|
||||
func ValidateBindVariable(bv *querypb.BindVariable) error {
|
||||
if bv == nil {
|
||||
return errors.New("bind variable is nil")
|
||||
}
|
||||
|
||||
if bv.Type == querypb.Type_TUPLE {
|
||||
if len(bv.Values) == 0 {
|
||||
return errors.New("empty tuple is not allowed")
|
||||
}
|
||||
for _, val := range bv.Values {
|
||||
if val.Type == querypb.Type_TUPLE {
|
||||
return errors.New("tuple not allowed inside another tuple")
|
||||
}
|
||||
if err := ValidateBindVariable(&querypb.BindVariable{Type: val.Type, Value: val.Value}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If NewValue succeeds, the value is valid.
|
||||
_, err := NewValue(bv.Type, bv.Value)
|
||||
return err
|
||||
}
|
||||
|
||||
// BindVariableToValue converts a bind var into a Value.
|
||||
func BindVariableToValue(bv *querypb.BindVariable) (Value, error) {
|
||||
if bv.Type == querypb.Type_TUPLE {
|
||||
return NULL, errors.New("cannot convert a TUPLE bind var into a value")
|
||||
}
|
||||
return MakeTrusted(bv.Type, bv.Value), nil
|
||||
}
|
||||
|
||||
// BindVariablesEqual compares two maps of bind variables.
|
||||
func BindVariablesEqual(x, y map[string]*querypb.BindVariable) bool {
|
||||
return reflect.DeepEqual(&querypb.BoundQuery{BindVariables: x}, &querypb.BoundQuery{BindVariables: y})
|
||||
}
|
||||
|
||||
// CopyBindVariables returns a shallow-copy of the given bindVariables map.
|
||||
func CopyBindVariables(bindVariables map[string]*querypb.BindVariable) map[string]*querypb.BindVariable {
|
||||
result := make(map[string]*querypb.BindVariable, len(bindVariables))
|
||||
for key, value := range bindVariables {
|
||||
result[key] = value
|
||||
}
|
||||
return result
|
||||
}
|
259
vendor/github.com/xwb1989/sqlparser/dependency/sqltypes/plan_value.go
generated
vendored
Normal file
259
vendor/github.com/xwb1989/sqlparser/dependency/sqltypes/plan_value.go
generated
vendored
Normal file
|
@ -0,0 +1,259 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/querypb"
|
||||
)
|
||||
|
||||
// PlanValue represents a value or a list of values for
|
||||
// a column that will later be resolved using bind vars and used
|
||||
// to perform plan actions like generating the final query or
|
||||
// deciding on a route.
|
||||
//
|
||||
// Plan values are typically used as a slice ([]planValue)
|
||||
// where each entry is for one column. For situations where
|
||||
// the required output is a list of rows (like in the case
|
||||
// of multi-value inserts), the representation is pivoted.
|
||||
// For example, a statement like this:
|
||||
// INSERT INTO t VALUES (1, 2), (3, 4)
|
||||
// will be represented as follows:
|
||||
// []PlanValue{
|
||||
// Values: {1, 3},
|
||||
// Values: {2, 4},
|
||||
// }
|
||||
//
|
||||
// For WHERE clause items that contain a combination of
|
||||
// equality expressions and IN clauses like this:
|
||||
// WHERE pk1 = 1 AND pk2 IN (2, 3, 4)
|
||||
// The plan values will be represented as follows:
|
||||
// []PlanValue{
|
||||
// Value: 1,
|
||||
// Values: {2, 3, 4},
|
||||
// }
|
||||
// When converted into rows, columns with single values
|
||||
// are replicated as the same for all rows:
|
||||
// [][]Value{
|
||||
// {1, 2},
|
||||
// {1, 3},
|
||||
// {1, 4},
|
||||
// }
|
||||
type PlanValue struct {
|
||||
Key string
|
||||
Value Value
|
||||
ListKey string
|
||||
Values []PlanValue
|
||||
}
|
||||
|
||||
// IsNull returns true if the PlanValue is NULL.
|
||||
func (pv PlanValue) IsNull() bool {
|
||||
return pv.Key == "" && pv.Value.IsNull() && pv.ListKey == "" && pv.Values == nil
|
||||
}
|
||||
|
||||
// IsList returns true if the PlanValue is a list.
|
||||
func (pv PlanValue) IsList() bool {
|
||||
return pv.ListKey != "" || pv.Values != nil
|
||||
}
|
||||
|
||||
// ResolveValue resolves a PlanValue as a single value based on the supplied bindvars.
|
||||
func (pv PlanValue) ResolveValue(bindVars map[string]*querypb.BindVariable) (Value, error) {
|
||||
switch {
|
||||
case pv.Key != "":
|
||||
bv, err := pv.lookupValue(bindVars)
|
||||
if err != nil {
|
||||
return NULL, err
|
||||
}
|
||||
return MakeTrusted(bv.Type, bv.Value), nil
|
||||
case !pv.Value.IsNull():
|
||||
return pv.Value, nil
|
||||
case pv.ListKey != "" || pv.Values != nil:
|
||||
// This code is unreachable because the parser does not allow
|
||||
// multi-value constructs where a single value is expected.
|
||||
return NULL, errors.New("a list was supplied where a single value was expected")
|
||||
}
|
||||
return NULL, nil
|
||||
}
|
||||
|
||||
func (pv PlanValue) lookupValue(bindVars map[string]*querypb.BindVariable) (*querypb.BindVariable, error) {
|
||||
bv, ok := bindVars[pv.Key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing bind var %s", pv.Key)
|
||||
}
|
||||
if bv.Type == querypb.Type_TUPLE {
|
||||
return nil, fmt.Errorf("TUPLE was supplied for single value bind var %s", pv.ListKey)
|
||||
}
|
||||
return bv, nil
|
||||
}
|
||||
|
||||
// ResolveList resolves a PlanValue as a list of values based on the supplied bindvars.
|
||||
func (pv PlanValue) ResolveList(bindVars map[string]*querypb.BindVariable) ([]Value, error) {
|
||||
switch {
|
||||
case pv.ListKey != "":
|
||||
bv, err := pv.lookupList(bindVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
values := make([]Value, 0, len(bv.Values))
|
||||
for _, val := range bv.Values {
|
||||
values = append(values, MakeTrusted(val.Type, val.Value))
|
||||
}
|
||||
return values, nil
|
||||
case pv.Values != nil:
|
||||
values := make([]Value, 0, len(pv.Values))
|
||||
for _, val := range pv.Values {
|
||||
v, err := val.ResolveValue(bindVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
values = append(values, v)
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
// This code is unreachable because the parser does not allow
|
||||
// single value constructs where multiple values are expected.
|
||||
return nil, errors.New("a single value was supplied where a list was expected")
|
||||
}
|
||||
|
||||
func (pv PlanValue) lookupList(bindVars map[string]*querypb.BindVariable) (*querypb.BindVariable, error) {
|
||||
bv, ok := bindVars[pv.ListKey]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing bind var %s", pv.ListKey)
|
||||
}
|
||||
if bv.Type != querypb.Type_TUPLE {
|
||||
return nil, fmt.Errorf("single value was supplied for TUPLE bind var %s", pv.ListKey)
|
||||
}
|
||||
return bv, nil
|
||||
}
|
||||
|
||||
// MarshalJSON should be used only for testing.
|
||||
func (pv PlanValue) MarshalJSON() ([]byte, error) {
|
||||
switch {
|
||||
case pv.Key != "":
|
||||
return json.Marshal(":" + pv.Key)
|
||||
case !pv.Value.IsNull():
|
||||
if pv.Value.IsIntegral() {
|
||||
return pv.Value.ToBytes(), nil
|
||||
}
|
||||
return json.Marshal(pv.Value.ToString())
|
||||
case pv.ListKey != "":
|
||||
return json.Marshal("::" + pv.ListKey)
|
||||
case pv.Values != nil:
|
||||
return json.Marshal(pv.Values)
|
||||
}
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
func rowCount(pvs []PlanValue, bindVars map[string]*querypb.BindVariable) (int, error) {
|
||||
count := -1
|
||||
setCount := func(l int) error {
|
||||
switch count {
|
||||
case -1:
|
||||
count = l
|
||||
return nil
|
||||
case l:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("mismatch in number of column values")
|
||||
}
|
||||
}
|
||||
|
||||
for _, pv := range pvs {
|
||||
switch {
|
||||
case pv.Key != "" || !pv.Value.IsNull():
|
||||
continue
|
||||
case pv.Values != nil:
|
||||
if err := setCount(len(pv.Values)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case pv.ListKey != "":
|
||||
bv, err := pv.lookupList(bindVars)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := setCount(len(bv.Values)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if count == -1 {
|
||||
// If there were no lists inside, it was a single row.
|
||||
// Note that count can never be 0 because there is enough
|
||||
// protection at the top level: list bind vars must have
|
||||
// at least one value (enforced by vtgate), and AST lists
|
||||
// must have at least one value (enforced by the parser).
|
||||
// Also lists created internally after vtgate validation
|
||||
// ensure at least one value.
|
||||
// TODO(sougou): verify and change API to enforce this.
|
||||
return 1, nil
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ResolveRows resolves a []PlanValue as rows based on the supplied bindvars.
|
||||
func ResolveRows(pvs []PlanValue, bindVars map[string]*querypb.BindVariable) ([][]Value, error) {
|
||||
count, err := rowCount(pvs, bindVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Allocate the rows.
|
||||
rows := make([][]Value, count)
|
||||
for i := range rows {
|
||||
rows[i] = make([]Value, len(pvs))
|
||||
}
|
||||
|
||||
// Using j becasue we're resolving by columns.
|
||||
for j, pv := range pvs {
|
||||
switch {
|
||||
case pv.Key != "":
|
||||
bv, err := pv.lookupValue(bindVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range rows {
|
||||
rows[i][j] = MakeTrusted(bv.Type, bv.Value)
|
||||
}
|
||||
case !pv.Value.IsNull():
|
||||
for i := range rows {
|
||||
rows[i][j] = pv.Value
|
||||
}
|
||||
case pv.ListKey != "":
|
||||
bv, err := pv.lookupList(bindVars)
|
||||
if err != nil {
|
||||
// This code is unreachable because pvRowCount already checks this.
|
||||
return nil, err
|
||||
}
|
||||
for i := range rows {
|
||||
rows[i][j] = MakeTrusted(bv.Values[i].Type, bv.Values[i].Value)
|
||||
}
|
||||
case pv.Values != nil:
|
||||
for i := range rows {
|
||||
rows[i][j], err = pv.Values[i].ResolveValue(bindVars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// default case is a NULL value, which the row values are already initialized to.
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
}
|
154
vendor/github.com/xwb1989/sqlparser/dependency/sqltypes/testing.go
generated
vendored
Normal file
154
vendor/github.com/xwb1989/sqlparser/dependency/sqltypes/testing.go
generated
vendored
Normal file
|
@ -0,0 +1,154 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
querypb "github.com/xwb1989/sqlparser/dependency/querypb"
|
||||
)
|
||||
|
||||
// Functions in this file should only be used for testing.
|
||||
// This is an experiment to see if test code bloat can be
|
||||
// reduced and readability improved.
|
||||
|
||||
/*
|
||||
// MakeTestFields builds a []*querypb.Field for testing.
|
||||
// fields := sqltypes.MakeTestFields(
|
||||
// "a|b",
|
||||
// "int64|varchar",
|
||||
// )
|
||||
// The field types are as defined in querypb and are case
|
||||
// insensitive. Column delimiters must be used only to separate
|
||||
// strings and not at the beginning or the end.
|
||||
func MakeTestFields(names, types string) []*querypb.Field {
|
||||
n := split(names)
|
||||
t := split(types)
|
||||
var fields []*querypb.Field
|
||||
for i := range n {
|
||||
fields = append(fields, &querypb.Field{
|
||||
Name: n[i],
|
||||
Type: querypb.Type(querypb.Type_value[strings.ToUpper(t[i])]),
|
||||
})
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// MakeTestResult builds a *sqltypes.Result object for testing.
|
||||
// result := sqltypes.MakeTestResult(
|
||||
// fields,
|
||||
// " 1|a",
|
||||
// "10|abcd",
|
||||
// )
|
||||
// The field type values are set as the types for the rows built.
|
||||
// Spaces are trimmed from row values. "null" is treated as NULL.
|
||||
func MakeTestResult(fields []*querypb.Field, rows ...string) *Result {
|
||||
result := &Result{
|
||||
Fields: fields,
|
||||
}
|
||||
if len(rows) > 0 {
|
||||
result.Rows = make([][]Value, len(rows))
|
||||
}
|
||||
for i, row := range rows {
|
||||
result.Rows[i] = make([]Value, len(fields))
|
||||
for j, col := range split(row) {
|
||||
if col == "null" {
|
||||
continue
|
||||
}
|
||||
result.Rows[i][j] = MakeTrusted(fields[j].Type, []byte(col))
|
||||
}
|
||||
}
|
||||
result.RowsAffected = uint64(len(result.Rows))
|
||||
return result
|
||||
}
|
||||
|
||||
// MakeTestStreamingResults builds a list of results for streaming.
|
||||
// results := sqltypes.MakeStreamingResults(
|
||||
// fields,
|
||||
// "1|a",
|
||||
// "2|b",
|
||||
// "---",
|
||||
// "c|c",
|
||||
// )
|
||||
// The first result contains only the fields. Subsequent results
|
||||
// are built using the field types. Every input that starts with a "-"
|
||||
// is treated as streaming delimiter for one result. A final
|
||||
// delimiter must not be supplied.
|
||||
func MakeTestStreamingResults(fields []*querypb.Field, rows ...string) []*Result {
|
||||
var results []*Result
|
||||
results = append(results, &Result{Fields: fields})
|
||||
start := 0
|
||||
cur := 0
|
||||
// Add a final streaming delimiter to simplify the loop below.
|
||||
rows = append(rows, "-")
|
||||
for cur < len(rows) {
|
||||
if rows[cur][0] != '-' {
|
||||
cur++
|
||||
continue
|
||||
}
|
||||
result := MakeTestResult(fields, rows[start:cur]...)
|
||||
result.Fields = nil
|
||||
result.RowsAffected = 0
|
||||
results = append(results, result)
|
||||
start = cur + 1
|
||||
cur = start
|
||||
}
|
||||
return results
|
||||
}
|
||||
*/
|
||||
|
||||
// TestBindVariable makes a *querypb.BindVariable from
|
||||
// an interface{}.It panics on invalid input.
|
||||
// This function should only be used for testing.
|
||||
func TestBindVariable(v interface{}) *querypb.BindVariable {
|
||||
if v == nil {
|
||||
return NullBindVariable
|
||||
}
|
||||
bv, err := BuildBindVariable(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return bv
|
||||
}
|
||||
|
||||
// TestValue builds a Value from typ and val.
|
||||
// This function should only be used for testing.
|
||||
func TestValue(typ querypb.Type, val string) Value {
|
||||
return MakeTrusted(typ, []byte(val))
|
||||
}
|
||||
|
||||
/*
|
||||
// PrintResults prints []*Results into a string.
|
||||
// This function should only be used for testing.
|
||||
func PrintResults(results []*Result) string {
|
||||
b := new(bytes.Buffer)
|
||||
for i, r := range results {
|
||||
if i == 0 {
|
||||
fmt.Fprintf(b, "%v", r)
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(b, ", %v", r)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func split(str string) []string {
|
||||
splits := strings.Split(str, "|")
|
||||
for i, v := range splits {
|
||||
splits[i] = strings.TrimSpace(v)
|
||||
}
|
||||
return splits
|
||||
}
|
||||
*/
|
|
@ -0,0 +1,288 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/querypb"
|
||||
)
|
||||
|
||||
// This file provides wrappers and support
|
||||
// functions for querypb.Type.
|
||||
|
||||
// These bit flags can be used to query on the
// common properties of types. Each one is the int form of the
// corresponding querypb.Flag_* value, and is tested against a
// querypb.Type with a bitwise AND by the Is* helpers in this file.
const (
	flagIsIntegral = int(querypb.Flag_ISINTEGRAL)
	flagIsUnsigned = int(querypb.Flag_ISUNSIGNED)
	flagIsFloat    = int(querypb.Flag_ISFLOAT)
	flagIsQuoted   = int(querypb.Flag_ISQUOTED)
	flagIsText     = int(querypb.Flag_ISTEXT)
	flagIsBinary   = int(querypb.Flag_ISBINARY)
)
|
||||
|
||||
// IsIntegral returns true if querypb.Type is an integral
|
||||
// (signed/unsigned) that can be represented using
|
||||
// up to 64 binary bits.
|
||||
// If you have a Value object, use its member function.
|
||||
func IsIntegral(t querypb.Type) bool {
|
||||
return int(t)&flagIsIntegral == flagIsIntegral
|
||||
}
|
||||
|
||||
// IsSigned returns true if querypb.Type is a signed integral.
|
||||
// If you have a Value object, use its member function.
|
||||
func IsSigned(t querypb.Type) bool {
|
||||
return int(t)&(flagIsIntegral|flagIsUnsigned) == flagIsIntegral
|
||||
}
|
||||
|
||||
// IsUnsigned returns true if querypb.Type is an unsigned integral.
|
||||
// Caution: this is not the same as !IsSigned.
|
||||
// If you have a Value object, use its member function.
|
||||
func IsUnsigned(t querypb.Type) bool {
|
||||
return int(t)&(flagIsIntegral|flagIsUnsigned) == flagIsIntegral|flagIsUnsigned
|
||||
}
|
||||
|
||||
// IsFloat returns true is querypb.Type is a floating point.
|
||||
// If you have a Value object, use its member function.
|
||||
func IsFloat(t querypb.Type) bool {
|
||||
return int(t)&flagIsFloat == flagIsFloat
|
||||
}
|
||||
|
||||
// IsQuoted returns true if querypb.Type is a quoted text or binary.
|
||||
// If you have a Value object, use its member function.
|
||||
func IsQuoted(t querypb.Type) bool {
|
||||
return int(t)&flagIsQuoted == flagIsQuoted
|
||||
}
|
||||
|
||||
// IsText returns true if querypb.Type is a text.
|
||||
// If you have a Value object, use its member function.
|
||||
func IsText(t querypb.Type) bool {
|
||||
return int(t)&flagIsText == flagIsText
|
||||
}
|
||||
|
||||
// IsBinary returns true if querypb.Type is a binary.
|
||||
// If you have a Value object, use its member function.
|
||||
func IsBinary(t querypb.Type) bool {
|
||||
return int(t)&flagIsBinary == flagIsBinary
|
||||
}
|
||||
|
||||
// isNumber returns true if the type is any type of number.
|
||||
func isNumber(t querypb.Type) bool {
|
||||
return IsIntegral(t) || IsFloat(t) || t == Decimal
|
||||
}
|
||||
|
||||
// Vitess data types. These are idiomatically
// named synonyms for the querypb.Type values.
// Although these constants are interchangeable,
// they should be treated as different from querypb.Type.
// Use the synonyms only to refer to the type in Value.
// For proto variables, use the querypb.Type constants
// instead.
// The following conditions are non-overlapping
// and cover all types: IsSigned(), IsUnsigned(),
// IsFloat(), IsQuoted(), Null, Decimal, Expression.
// Also, IsIntegral() == (IsSigned()||IsUnsigned()).
// TestCategory needs to be updated accordingly if
// you add a new type.
// If IsBinary or IsText is true, then IsQuoted is
// also true. But there are IsQuoted types that are
// neither binary nor text.
// querypb.Type_TUPLE is not included in this list
// because it's not a valid Value type.
// TODO(sougou): provide a categorization function
// that returns enums, which will allow for cleaner
// switch statements for those who want to cover types
// by their category.
const (
	Null       = querypb.Type_NULL_TYPE
	Int8       = querypb.Type_INT8
	Uint8      = querypb.Type_UINT8
	Int16      = querypb.Type_INT16
	Uint16     = querypb.Type_UINT16
	Int24      = querypb.Type_INT24
	Uint24     = querypb.Type_UINT24
	Int32      = querypb.Type_INT32
	Uint32     = querypb.Type_UINT32
	Int64      = querypb.Type_INT64
	Uint64     = querypb.Type_UINT64
	Float32    = querypb.Type_FLOAT32
	Float64    = querypb.Type_FLOAT64
	Timestamp  = querypb.Type_TIMESTAMP
	Date       = querypb.Type_DATE
	Time       = querypb.Type_TIME
	Datetime   = querypb.Type_DATETIME
	Year       = querypb.Type_YEAR
	Decimal    = querypb.Type_DECIMAL
	Text       = querypb.Type_TEXT
	Blob       = querypb.Type_BLOB
	VarChar    = querypb.Type_VARCHAR
	VarBinary  = querypb.Type_VARBINARY
	Char       = querypb.Type_CHAR
	Binary     = querypb.Type_BINARY
	Bit        = querypb.Type_BIT
	Enum       = querypb.Type_ENUM
	Set        = querypb.Type_SET
	Geometry   = querypb.Type_GEOMETRY
	TypeJSON   = querypb.Type_JSON
	Expression = querypb.Type_EXPRESSION
)
|
||||
|
||||
// bit-shift the mysql flags by two byte so we
// can merge them with the mysql or vitess types.
// NOTE(review): no shift is applied in this file; the constants are used
// directly as bit masks against the flags argument of modifyType. The values
// look like raw MySQL column flags — confirm and update the sentence above.
const (
	mysqlUnsigned = 32
	mysqlBinary   = 128
	mysqlEnum     = 256
	mysqlSet      = 2048
)
|
||||
|
||||
// mysqlToType maps a MySQL type code to its base vitess type, before any
// flag-based adjustment by modifyType. Codes absent from the map are
// reported as unsupported by MySQLToType.
// If you add to this map, make sure you add a test case
// in tabletserver/endtoend.
var mysqlToType = map[int64]querypb.Type{
	1:   Int8,
	2:   Int16,
	3:   Int32,
	4:   Float32,
	5:   Float64,
	6:   Null,
	7:   Timestamp,
	8:   Int64,
	9:   Int24,
	10:  Date,
	11:  Time,
	12:  Datetime,
	13:  Year,
	16:  Bit,
	245: TypeJSON,
	246: Decimal,
	249: Text,
	250: Text,
	251: Text,
	252: Text,
	253: VarChar,
	254: Char,
	255: Geometry,
}
|
||||
|
||||
// modifyType modifies the vitess type based on the
|
||||
// mysql flag. The function checks specific flags based
|
||||
// on the type. This allows us to ignore stray flags
|
||||
// that MySQL occasionally sets.
|
||||
func modifyType(typ querypb.Type, flags int64) querypb.Type {
|
||||
switch typ {
|
||||
case Int8:
|
||||
if flags&mysqlUnsigned != 0 {
|
||||
return Uint8
|
||||
}
|
||||
return Int8
|
||||
case Int16:
|
||||
if flags&mysqlUnsigned != 0 {
|
||||
return Uint16
|
||||
}
|
||||
return Int16
|
||||
case Int32:
|
||||
if flags&mysqlUnsigned != 0 {
|
||||
return Uint32
|
||||
}
|
||||
return Int32
|
||||
case Int64:
|
||||
if flags&mysqlUnsigned != 0 {
|
||||
return Uint64
|
||||
}
|
||||
return Int64
|
||||
case Int24:
|
||||
if flags&mysqlUnsigned != 0 {
|
||||
return Uint24
|
||||
}
|
||||
return Int24
|
||||
case Text:
|
||||
if flags&mysqlBinary != 0 {
|
||||
return Blob
|
||||
}
|
||||
return Text
|
||||
case VarChar:
|
||||
if flags&mysqlBinary != 0 {
|
||||
return VarBinary
|
||||
}
|
||||
return VarChar
|
||||
case Char:
|
||||
if flags&mysqlBinary != 0 {
|
||||
return Binary
|
||||
}
|
||||
if flags&mysqlEnum != 0 {
|
||||
return Enum
|
||||
}
|
||||
if flags&mysqlSet != 0 {
|
||||
return Set
|
||||
}
|
||||
return Char
|
||||
}
|
||||
return typ
|
||||
}
|
||||
|
||||
// MySQLToType computes the vitess type from mysql type and flags.
|
||||
func MySQLToType(mysqlType, flags int64) (typ querypb.Type, err error) {
|
||||
result, ok := mysqlToType[mysqlType]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unsupported type: %d", mysqlType)
|
||||
}
|
||||
return modifyType(result, flags), nil
|
||||
}
|
||||
|
||||
// typeToMySQL is the reverse of mysqlToType: for each vitess type it records
// the mysql type code (typ) and the mysql flags (flags) reported by
// TypeToMySQL.
var typeToMySQL = map[querypb.Type]struct {
	typ   int64
	flags int64
}{
	Int8:      {typ: 1},
	Uint8:     {typ: 1, flags: mysqlUnsigned},
	Int16:     {typ: 2},
	Uint16:    {typ: 2, flags: mysqlUnsigned},
	Int32:     {typ: 3},
	Uint32:    {typ: 3, flags: mysqlUnsigned},
	Float32:   {typ: 4},
	Float64:   {typ: 5},
	Null:      {typ: 6, flags: mysqlBinary},
	Timestamp: {typ: 7},
	Int64:     {typ: 8},
	Uint64:    {typ: 8, flags: mysqlUnsigned},
	Int24:     {typ: 9},
	Uint24:    {typ: 9, flags: mysqlUnsigned},
	Date:      {typ: 10, flags: mysqlBinary},
	Time:      {typ: 11, flags: mysqlBinary},
	Datetime:  {typ: 12, flags: mysqlBinary},
	Year:      {typ: 13, flags: mysqlUnsigned},
	Bit:       {typ: 16, flags: mysqlUnsigned},
	TypeJSON:  {typ: 245},
	Decimal:   {typ: 246},
	Text:      {typ: 252},
	Blob:      {typ: 252, flags: mysqlBinary},
	VarChar:   {typ: 253},
	VarBinary: {typ: 253, flags: mysqlBinary},
	Char:      {typ: 254},
	Binary:    {typ: 254, flags: mysqlBinary},
	Enum:      {typ: 254, flags: mysqlEnum},
	Set:       {typ: 254, flags: mysqlSet},
	Geometry:  {typ: 255},
}
|
||||
|
||||
// TypeToMySQL returns the equivalent mysql type and flag for a vitess type.
|
||||
func TypeToMySQL(typ querypb.Type) (mysqlType, flags int64) {
|
||||
val := typeToMySQL[typ]
|
||||
return val.typ, val.flags
|
||||
}
|
|
@ -0,0 +1,376 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package sqltypes implements interfaces and types that represent SQL values.
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/bytes2"
|
||||
"github.com/xwb1989/sqlparser/dependency/hack"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/querypb"
|
||||
)
|
||||
|
||||
var (
	// NULL represents the NULL value. It is the zero Value.
	NULL = Value{}

	// DontEscape tells you if a character should not be escaped.
	// It is the marker stored in SQLEncodeMap and SQLDecodeMap for bytes
	// that have no escape mapping.
	DontEscape = byte(255)

	// nullstr is the text written for NULL values by the encoders and by
	// Value.MarshalJSON.
	nullstr = []byte("null")
)
|
||||
|
||||
// BinWriter interface is used for encoding values.
// Types like bytes.Buffer conform to this interface.
// We expect the writer objects to be in-memory buffers.
// So, we don't expect the write operations to fail.
// The encoders in this file rely on that and ignore the values returned by
// Write.
type BinWriter interface {
	Write([]byte) (int, error)
}
|
||||
|
||||
// Value can store any SQL value. If the value represents
// an integral type, the bytes are always stored as a canonical
// representation that matches how MySQL returns such values.
type Value struct {
	// typ is the type of the value; the zero Value has type Null.
	typ querypb.Type
	// val holds the raw bytes of the value.
	val []byte
}
|
||||
|
||||
// NewValue builds a Value using typ and val. If the value and typ
|
||||
// don't match, it returns an error.
|
||||
func NewValue(typ querypb.Type, val []byte) (v Value, err error) {
|
||||
switch {
|
||||
case IsSigned(typ):
|
||||
if _, err := strconv.ParseInt(string(val), 0, 64); err != nil {
|
||||
return NULL, err
|
||||
}
|
||||
return MakeTrusted(typ, val), nil
|
||||
case IsUnsigned(typ):
|
||||
if _, err := strconv.ParseUint(string(val), 0, 64); err != nil {
|
||||
return NULL, err
|
||||
}
|
||||
return MakeTrusted(typ, val), nil
|
||||
case IsFloat(typ) || typ == Decimal:
|
||||
if _, err := strconv.ParseFloat(string(val), 64); err != nil {
|
||||
return NULL, err
|
||||
}
|
||||
return MakeTrusted(typ, val), nil
|
||||
case IsQuoted(typ) || typ == Null:
|
||||
return MakeTrusted(typ, val), nil
|
||||
}
|
||||
// All other types are unsafe or invalid.
|
||||
return NULL, fmt.Errorf("invalid type specified for MakeValue: %v", typ)
|
||||
}
|
||||
|
||||
// MakeTrusted makes a new Value based on the type.
|
||||
// This function should only be used if you know the value
|
||||
// and type conform to the rules. Every place this function is
|
||||
// called, a comment is needed that explains why it's justified.
|
||||
// Exceptions: The current package and mysql package do not need
|
||||
// comments. Other packages can also use the function to create
|
||||
// VarBinary or VarChar values.
|
||||
func MakeTrusted(typ querypb.Type, val []byte) Value {
|
||||
if typ == Null {
|
||||
return NULL
|
||||
}
|
||||
return Value{typ: typ, val: val}
|
||||
}
|
||||
|
||||
// NewInt64 builds an Int64 Value.
|
||||
func NewInt64(v int64) Value {
|
||||
return MakeTrusted(Int64, strconv.AppendInt(nil, v, 10))
|
||||
}
|
||||
|
||||
// NewInt32 builds an Int64 Value.
|
||||
func NewInt32(v int32) Value {
|
||||
return MakeTrusted(Int32, strconv.AppendInt(nil, int64(v), 10))
|
||||
}
|
||||
|
||||
// NewUint64 builds an Uint64 Value.
|
||||
func NewUint64(v uint64) Value {
|
||||
return MakeTrusted(Uint64, strconv.AppendUint(nil, v, 10))
|
||||
}
|
||||
|
||||
// NewFloat64 builds an Float64 Value.
|
||||
func NewFloat64(v float64) Value {
|
||||
return MakeTrusted(Float64, strconv.AppendFloat(nil, v, 'g', -1, 64))
|
||||
}
|
||||
|
||||
// NewVarChar builds a VarChar Value.
|
||||
func NewVarChar(v string) Value {
|
||||
return MakeTrusted(VarChar, []byte(v))
|
||||
}
|
||||
|
||||
// NewVarBinary builds a VarBinary Value.
|
||||
// The input is a string because it's the most common use case.
|
||||
func NewVarBinary(v string) Value {
|
||||
return MakeTrusted(VarBinary, []byte(v))
|
||||
}
|
||||
|
||||
// NewIntegral builds an integral type from a string representaion.
|
||||
// The type will be Int64 or Uint64. Int64 will be preferred where possible.
|
||||
func NewIntegral(val string) (n Value, err error) {
|
||||
signed, err := strconv.ParseInt(val, 0, 64)
|
||||
if err == nil {
|
||||
return MakeTrusted(Int64, strconv.AppendInt(nil, signed, 10)), nil
|
||||
}
|
||||
unsigned, err := strconv.ParseUint(val, 0, 64)
|
||||
if err != nil {
|
||||
return Value{}, err
|
||||
}
|
||||
return MakeTrusted(Uint64, strconv.AppendUint(nil, unsigned, 10)), nil
|
||||
}
|
||||
|
||||
// InterfaceToValue builds a value from a go type.
|
||||
// Supported types are nil, int64, uint64, float64,
|
||||
// string and []byte.
|
||||
// This function is deprecated. Use the type-specific
|
||||
// functions instead.
|
||||
func InterfaceToValue(goval interface{}) (Value, error) {
|
||||
switch goval := goval.(type) {
|
||||
case nil:
|
||||
return NULL, nil
|
||||
case []byte:
|
||||
return MakeTrusted(VarBinary, goval), nil
|
||||
case int64:
|
||||
return NewInt64(goval), nil
|
||||
case uint64:
|
||||
return NewUint64(goval), nil
|
||||
case float64:
|
||||
return NewFloat64(goval), nil
|
||||
case string:
|
||||
return NewVarChar(goval), nil
|
||||
default:
|
||||
return NULL, fmt.Errorf("unexpected type %T: %v", goval, goval)
|
||||
}
|
||||
}
|
||||
|
||||
// Type returns the type of Value.
|
||||
func (v Value) Type() querypb.Type {
|
||||
return v.typ
|
||||
}
|
||||
|
||||
// Raw returns the internal represenation of the value. For newer types,
|
||||
// this may not match MySQL's representation.
|
||||
func (v Value) Raw() []byte {
|
||||
return v.val
|
||||
}
|
||||
|
||||
// ToBytes returns the value as MySQL would return it as []byte.
|
||||
// In contrast, Raw returns the internal representation of the Value, which may not
|
||||
// match MySQL's representation for newer types.
|
||||
// If the value is not convertible like in the case of Expression, it returns nil.
|
||||
func (v Value) ToBytes() []byte {
|
||||
if v.typ == Expression {
|
||||
return nil
|
||||
}
|
||||
return v.val
|
||||
}
|
||||
|
||||
// Len returns the length.
|
||||
func (v Value) Len() int {
|
||||
return len(v.val)
|
||||
}
|
||||
|
||||
// ToString returns the value as MySQL would return it as string.
|
||||
// If the value is not convertible like in the case of Expression, it returns nil.
|
||||
func (v Value) ToString() string {
|
||||
if v.typ == Expression {
|
||||
return ""
|
||||
}
|
||||
return hack.String(v.val)
|
||||
}
|
||||
|
||||
// String returns a printable version of the value.
|
||||
func (v Value) String() string {
|
||||
if v.typ == Null {
|
||||
return "NULL"
|
||||
}
|
||||
if v.IsQuoted() {
|
||||
return fmt.Sprintf("%v(%q)", v.typ, v.val)
|
||||
}
|
||||
return fmt.Sprintf("%v(%s)", v.typ, v.val)
|
||||
}
|
||||
|
||||
// EncodeSQL encodes the value into an SQL statement. Can be binary.
|
||||
func (v Value) EncodeSQL(b BinWriter) {
|
||||
switch {
|
||||
case v.typ == Null:
|
||||
b.Write(nullstr)
|
||||
case v.IsQuoted():
|
||||
encodeBytesSQL(v.val, b)
|
||||
default:
|
||||
b.Write(v.val)
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeASCII encodes the value using 7-bit clean ascii bytes.
|
||||
func (v Value) EncodeASCII(b BinWriter) {
|
||||
switch {
|
||||
case v.typ == Null:
|
||||
b.Write(nullstr)
|
||||
case v.IsQuoted():
|
||||
encodeBytesASCII(v.val, b)
|
||||
default:
|
||||
b.Write(v.val)
|
||||
}
|
||||
}
|
||||
|
||||
// IsNull returns true if Value is null.
|
||||
func (v Value) IsNull() bool {
|
||||
return v.typ == Null
|
||||
}
|
||||
|
||||
// IsIntegral returns true if Value is an integral.
|
||||
func (v Value) IsIntegral() bool {
|
||||
return IsIntegral(v.typ)
|
||||
}
|
||||
|
||||
// IsSigned returns true if Value is a signed integral.
|
||||
func (v Value) IsSigned() bool {
|
||||
return IsSigned(v.typ)
|
||||
}
|
||||
|
||||
// IsUnsigned returns true if Value is an unsigned integral.
|
||||
func (v Value) IsUnsigned() bool {
|
||||
return IsUnsigned(v.typ)
|
||||
}
|
||||
|
||||
// IsFloat returns true if Value is a float.
|
||||
func (v Value) IsFloat() bool {
|
||||
return IsFloat(v.typ)
|
||||
}
|
||||
|
||||
// IsQuoted returns true if Value must be SQL-quoted.
|
||||
func (v Value) IsQuoted() bool {
|
||||
return IsQuoted(v.typ)
|
||||
}
|
||||
|
||||
// IsText returns true if Value is a collatable text.
|
||||
func (v Value) IsText() bool {
|
||||
return IsText(v.typ)
|
||||
}
|
||||
|
||||
// IsBinary returns true if Value is binary.
|
||||
func (v Value) IsBinary() bool {
|
||||
return IsBinary(v.typ)
|
||||
}
|
||||
|
||||
// MarshalJSON should only be used for testing.
|
||||
// It's not a complete implementation.
|
||||
func (v Value) MarshalJSON() ([]byte, error) {
|
||||
switch {
|
||||
case v.IsQuoted():
|
||||
return json.Marshal(v.ToString())
|
||||
case v.typ == Null:
|
||||
return nullstr, nil
|
||||
}
|
||||
return v.val, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON should only be used for testing.
|
||||
// It's not a complete implementation.
|
||||
func (v *Value) UnmarshalJSON(b []byte) error {
|
||||
if len(b) == 0 {
|
||||
return fmt.Errorf("error unmarshaling empty bytes")
|
||||
}
|
||||
var val interface{}
|
||||
var err error
|
||||
switch b[0] {
|
||||
case '-':
|
||||
var ival int64
|
||||
err = json.Unmarshal(b, &ival)
|
||||
val = ival
|
||||
case '"':
|
||||
var bval []byte
|
||||
err = json.Unmarshal(b, &bval)
|
||||
val = bval
|
||||
case 'n': // null
|
||||
err = json.Unmarshal(b, &val)
|
||||
default:
|
||||
var uval uint64
|
||||
err = json.Unmarshal(b, &uval)
|
||||
val = uval
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*v, err = InterfaceToValue(val)
|
||||
return err
|
||||
}
|
||||
|
||||
func encodeBytesSQL(val []byte, b BinWriter) {
|
||||
buf := &bytes2.Buffer{}
|
||||
buf.WriteByte('\'')
|
||||
for _, ch := range val {
|
||||
if encodedChar := SQLEncodeMap[ch]; encodedChar == DontEscape {
|
||||
buf.WriteByte(ch)
|
||||
} else {
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte(encodedChar)
|
||||
}
|
||||
}
|
||||
buf.WriteByte('\'')
|
||||
b.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
func encodeBytesASCII(val []byte, b BinWriter) {
|
||||
buf := &bytes2.Buffer{}
|
||||
buf.WriteByte('\'')
|
||||
encoder := base64.NewEncoder(base64.StdEncoding, buf)
|
||||
encoder.Write(val)
|
||||
encoder.Close()
|
||||
buf.WriteByte('\'')
|
||||
b.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
// SQLEncodeMap specifies how to escape binary data with '\'.
// It is indexed by the raw byte; the entry is the character to emit after the
// backslash, or DontEscape when the byte needs no escaping.
// Complies to http://dev.mysql.com/doc/refman/5.1/en/string-syntax.html
var SQLEncodeMap [256]byte

// SQLDecodeMap is the reverse of SQLEncodeMap: it is indexed by the character
// that follows the backslash and yields the raw byte.
var SQLDecodeMap [256]byte

// encodeRef lists the bytes that must be escaped (keys) together with the
// character written after the backslash (values). Both tables above are
// populated from it in init.
var encodeRef = map[byte]byte{
	'\x00': '0',
	'\'':   '\'',
	'"':    '"',
	'\b':   'b',
	'\n':   'n',
	'\r':   'r',
	'\t':   't',
	'\x1a': 'Z', // ctl-Z (ASCII 26)
	'\\':   '\\',
}
|
||||
|
||||
func init() {
|
||||
for i := range SQLEncodeMap {
|
||||
SQLEncodeMap[i] = DontEscape
|
||||
SQLDecodeMap[i] = DontEscape
|
||||
}
|
||||
for i := range SQLEncodeMap {
|
||||
if to, ok := encodeRef[byte(i)]; ok {
|
||||
SQLEncodeMap[byte(i)] = to
|
||||
SQLDecodeMap[to] = byte(i)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/sqltypes"
|
||||
)
|
||||
|
||||
// This file contains types that are 'Encodable'.
|
||||
|
||||
// Encodable is implemented by types that know how to write their own
// SQL representation.
type Encodable interface {
	// EncodeSQL appends the SQL form of the receiver to buf.
	EncodeSQL(buf *bytes.Buffer)
}
|
||||
|
||||
// InsertValues is a custom SQL encoder for the values of
|
||||
// an insert statement.
|
||||
type InsertValues [][]sqltypes.Value
|
||||
|
||||
// EncodeSQL performs the SQL encoding for InsertValues.
|
||||
func (iv InsertValues) EncodeSQL(buf *bytes.Buffer) {
|
||||
for i, rows := range iv {
|
||||
if i != 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
buf.WriteByte('(')
|
||||
for j, bv := range rows {
|
||||
if j != 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
bv.EncodeSQL(buf)
|
||||
}
|
||||
buf.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
// TupleEqualityList is for generating equality constraints
|
||||
// for tables that have composite primary keys.
|
||||
type TupleEqualityList struct {
|
||||
Columns []ColIdent
|
||||
Rows [][]sqltypes.Value
|
||||
}
|
||||
|
||||
// EncodeSQL generates the where clause constraints for the tuple
|
||||
// equality.
|
||||
func (tpl *TupleEqualityList) EncodeSQL(buf *bytes.Buffer) {
|
||||
if len(tpl.Columns) == 1 {
|
||||
tpl.encodeAsIn(buf)
|
||||
return
|
||||
}
|
||||
tpl.encodeAsEquality(buf)
|
||||
}
|
||||
|
||||
func (tpl *TupleEqualityList) encodeAsIn(buf *bytes.Buffer) {
|
||||
Append(buf, tpl.Columns[0])
|
||||
buf.WriteString(" in (")
|
||||
for i, r := range tpl.Rows {
|
||||
if i != 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
r[0].EncodeSQL(buf)
|
||||
}
|
||||
buf.WriteByte(')')
|
||||
}
|
||||
|
||||
func (tpl *TupleEqualityList) encodeAsEquality(buf *bytes.Buffer) {
|
||||
for i, r := range tpl.Rows {
|
||||
if i != 0 {
|
||||
buf.WriteString(" or ")
|
||||
}
|
||||
buf.WriteString("(")
|
||||
for j, c := range tpl.Columns {
|
||||
if j != 0 {
|
||||
buf.WriteString(" and ")
|
||||
}
|
||||
Append(buf, c)
|
||||
buf.WriteString(" = ")
|
||||
r[j].EncodeSQL(buf)
|
||||
}
|
||||
buf.WriteByte(')')
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
// FormatImpossibleQuery creates an impossible query in a TrackedBuffer.
|
||||
// An impossible query is a modified version of a query where all selects have where clauses that are
|
||||
// impossible for mysql to resolve. This is used in the vtgate and vttablet:
|
||||
//
|
||||
// - In the vtgate it's used for joins: if the first query returns no result, then vtgate uses the impossible
|
||||
// query just to fetch field info from vttablet
|
||||
// - In the vttablet, it's just an optimization: the field info is fetched once form MySQL, cached and reused
|
||||
// for subsequent queries
|
||||
func FormatImpossibleQuery(buf *TrackedBuffer, node SQLNode) {
|
||||
switch node := node.(type) {
|
||||
case *Select:
|
||||
buf.Myprintf("select %v from %v where 1 != 1", node.SelectExprs, node.From)
|
||||
if node.GroupBy != nil {
|
||||
node.GroupBy.Format(buf)
|
||||
}
|
||||
case *Union:
|
||||
buf.Myprintf("%v %s %v", node.Left, node.Type, node.Right)
|
||||
default:
|
||||
node.Format(buf)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,224 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/sqltypes"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/querypb"
|
||||
)
|
||||
|
||||
// Normalize changes the statement to use bind values, and
|
||||
// updates the bind vars to those values. The supplied prefix
|
||||
// is used to generate the bind var names. The function ensures
|
||||
// that there are no collisions with existing bind vars.
|
||||
// Within Select constructs, bind vars are deduped. This allows
|
||||
// us to identify vindex equality. Otherwise, every value is
|
||||
// treated as distinct.
|
||||
func Normalize(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) {
|
||||
nz := newNormalizer(stmt, bindVars, prefix)
|
||||
_ = Walk(nz.WalkStatement, stmt)
|
||||
}
|
||||
|
||||
type normalizer struct {
|
||||
stmt Statement
|
||||
bindVars map[string]*querypb.BindVariable
|
||||
prefix string
|
||||
reserved map[string]struct{}
|
||||
counter int
|
||||
vals map[string]string
|
||||
}
|
||||
|
||||
func newNormalizer(stmt Statement, bindVars map[string]*querypb.BindVariable, prefix string) *normalizer {
|
||||
return &normalizer{
|
||||
stmt: stmt,
|
||||
bindVars: bindVars,
|
||||
prefix: prefix,
|
||||
reserved: GetBindvars(stmt),
|
||||
counter: 1,
|
||||
vals: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// WalkStatement is the top level walk function.
|
||||
// If it encounters a Select, it switches to a mode
|
||||
// where variables are deduped.
|
||||
func (nz *normalizer) WalkStatement(node SQLNode) (bool, error) {
|
||||
switch node := node.(type) {
|
||||
case *Select:
|
||||
_ = Walk(nz.WalkSelect, node)
|
||||
// Don't continue
|
||||
return false, nil
|
||||
case *SQLVal:
|
||||
nz.convertSQLVal(node)
|
||||
case *ComparisonExpr:
|
||||
nz.convertComparison(node)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// WalkSelect normalizes the AST in Select mode.
|
||||
func (nz *normalizer) WalkSelect(node SQLNode) (bool, error) {
|
||||
switch node := node.(type) {
|
||||
case *SQLVal:
|
||||
nz.convertSQLValDedup(node)
|
||||
case *ComparisonExpr:
|
||||
nz.convertComparison(node)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (nz *normalizer) convertSQLValDedup(node *SQLVal) {
|
||||
// If value is too long, don't dedup.
|
||||
// Such values are most likely not for vindexes.
|
||||
// We save a lot of CPU because we avoid building
|
||||
// the key for them.
|
||||
if len(node.Val) > 256 {
|
||||
nz.convertSQLVal(node)
|
||||
return
|
||||
}
|
||||
|
||||
// Make the bindvar
|
||||
bval := nz.sqlToBindvar(node)
|
||||
if bval == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if there's a bindvar for that value already.
|
||||
var key string
|
||||
if bval.Type == sqltypes.VarBinary {
|
||||
// Prefixing strings with "'" ensures that a string
|
||||
// and number that have the same representation don't
|
||||
// collide.
|
||||
key = "'" + string(node.Val)
|
||||
} else {
|
||||
key = string(node.Val)
|
||||
}
|
||||
bvname, ok := nz.vals[key]
|
||||
if !ok {
|
||||
// If there's no such bindvar, make a new one.
|
||||
bvname = nz.newName()
|
||||
nz.vals[key] = bvname
|
||||
nz.bindVars[bvname] = bval
|
||||
}
|
||||
|
||||
// Modify the AST node to a bindvar.
|
||||
node.Type = ValArg
|
||||
node.Val = append([]byte(":"), bvname...)
|
||||
}
|
||||
|
||||
// convertSQLVal converts an SQLVal without the dedup.
|
||||
func (nz *normalizer) convertSQLVal(node *SQLVal) {
|
||||
bval := nz.sqlToBindvar(node)
|
||||
if bval == nil {
|
||||
return
|
||||
}
|
||||
|
||||
bvname := nz.newName()
|
||||
nz.bindVars[bvname] = bval
|
||||
|
||||
node.Type = ValArg
|
||||
node.Val = append([]byte(":"), bvname...)
|
||||
}
|
||||
|
||||
// convertComparison attempts to convert IN clauses to
|
||||
// use the list bind var construct. If it fails, it returns
|
||||
// with no change made. The walk function will then continue
|
||||
// and iterate on converting each individual value into separate
|
||||
// bind vars.
|
||||
func (nz *normalizer) convertComparison(node *ComparisonExpr) {
|
||||
if node.Operator != InStr && node.Operator != NotInStr {
|
||||
return
|
||||
}
|
||||
tupleVals, ok := node.Right.(ValTuple)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// The RHS is a tuple of values.
|
||||
// Make a list bindvar.
|
||||
bvals := &querypb.BindVariable{
|
||||
Type: querypb.Type_TUPLE,
|
||||
}
|
||||
for _, val := range tupleVals {
|
||||
bval := nz.sqlToBindvar(val)
|
||||
if bval == nil {
|
||||
return
|
||||
}
|
||||
bvals.Values = append(bvals.Values, &querypb.Value{
|
||||
Type: bval.Type,
|
||||
Value: bval.Value,
|
||||
})
|
||||
}
|
||||
bvname := nz.newName()
|
||||
nz.bindVars[bvname] = bvals
|
||||
// Modify RHS to be a list bindvar.
|
||||
node.Right = ListArg(append([]byte("::"), bvname...))
|
||||
}
|
||||
|
||||
func (nz *normalizer) sqlToBindvar(node SQLNode) *querypb.BindVariable {
|
||||
if node, ok := node.(*SQLVal); ok {
|
||||
var v sqltypes.Value
|
||||
var err error
|
||||
switch node.Type {
|
||||
case StrVal:
|
||||
v, err = sqltypes.NewValue(sqltypes.VarBinary, node.Val)
|
||||
case IntVal:
|
||||
v, err = sqltypes.NewValue(sqltypes.Int64, node.Val)
|
||||
case FloatVal:
|
||||
v, err = sqltypes.NewValue(sqltypes.Float64, node.Val)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return sqltypes.ValueBindVariable(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nz *normalizer) newName() string {
|
||||
for {
|
||||
newName := fmt.Sprintf("%s%d", nz.prefix, nz.counter)
|
||||
if _, ok := nz.reserved[newName]; !ok {
|
||||
nz.reserved[newName] = struct{}{}
|
||||
return newName
|
||||
}
|
||||
nz.counter++
|
||||
}
|
||||
}
|
||||
|
||||
// GetBindvars returns a map of the bind vars referenced in the statement.
|
||||
// TODO(sougou); This function gets called again from vtgate/planbuilder.
|
||||
// Ideally, this should be done only once.
|
||||
func GetBindvars(stmt Statement) map[string]struct{} {
|
||||
bindvars := make(map[string]struct{})
|
||||
_ = Walk(func(node SQLNode) (kontinue bool, err error) {
|
||||
switch node := node.(type) {
|
||||
case *SQLVal:
|
||||
if node.Type == ValArg {
|
||||
bindvars[string(node.Val[1:])] = struct{}{}
|
||||
}
|
||||
case ListArg:
|
||||
bindvars[string(node[2:])] = struct{}{}
|
||||
}
|
||||
return true, nil
|
||||
}, stmt)
|
||||
return bindvars
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/querypb"
|
||||
"github.com/xwb1989/sqlparser/dependency/sqltypes"
|
||||
)
|
||||
|
||||
// ParsedQuery represents a parsed query where
|
||||
// bind locations are precompued for fast substitutions.
|
||||
type ParsedQuery struct {
|
||||
Query string
|
||||
bindLocations []bindLocation
|
||||
}
|
||||
|
||||
// bindLocation identifies one bind variable reference inside a query string.
type bindLocation struct {
	// offset is the index in the query at which the reference starts.
	offset int
	// length is the number of bytes the reference occupies.
	length int
}
|
||||
|
||||
// NewParsedQuery returns a ParsedQuery of the ast.
|
||||
func NewParsedQuery(node SQLNode) *ParsedQuery {
|
||||
buf := NewTrackedBuffer(nil)
|
||||
buf.Myprintf("%v", node)
|
||||
return buf.ParsedQuery()
|
||||
}
|
||||
|
||||
// GenerateQuery generates a query by substituting the specified
|
||||
// bindVariables. The extras parameter specifies special parameters
|
||||
// that can perform custom encoding.
|
||||
func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]*querypb.BindVariable, extras map[string]Encodable) ([]byte, error) {
|
||||
if len(pq.bindLocations) == 0 {
|
||||
return []byte(pq.Query), nil
|
||||
}
|
||||
buf := bytes.NewBuffer(make([]byte, 0, len(pq.Query)))
|
||||
current := 0
|
||||
for _, loc := range pq.bindLocations {
|
||||
buf.WriteString(pq.Query[current:loc.offset])
|
||||
name := pq.Query[loc.offset : loc.offset+loc.length]
|
||||
if encodable, ok := extras[name[1:]]; ok {
|
||||
encodable.EncodeSQL(buf)
|
||||
} else {
|
||||
supplied, _, err := FetchBindVar(name, bindVariables)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
EncodeValue(buf, supplied)
|
||||
}
|
||||
current = loc.offset + loc.length
|
||||
}
|
||||
buf.WriteString(pq.Query[current:])
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// EncodeValue encodes one bind variable value into the query.
|
||||
func EncodeValue(buf *bytes.Buffer, value *querypb.BindVariable) {
|
||||
if value.Type != querypb.Type_TUPLE {
|
||||
// Since we already check for TUPLE, we don't expect an error.
|
||||
v, _ := sqltypes.BindVariableToValue(value)
|
||||
v.EncodeSQL(buf)
|
||||
return
|
||||
}
|
||||
|
||||
// It's a TUPLE.
|
||||
buf.WriteByte('(')
|
||||
for i, bv := range value.Values {
|
||||
if i != 0 {
|
||||
buf.WriteString(", ")
|
||||
}
|
||||
sqltypes.ProtoToValue(bv).EncodeSQL(buf)
|
||||
}
|
||||
buf.WriteByte(')')
|
||||
}
|
||||
|
||||
// FetchBindVar resolves the bind variable by fetching it from bindVariables.
|
||||
func FetchBindVar(name string, bindVariables map[string]*querypb.BindVariable) (val *querypb.BindVariable, isList bool, err error) {
|
||||
name = name[1:]
|
||||
if name[0] == ':' {
|
||||
name = name[1:]
|
||||
isList = true
|
||||
}
|
||||
supplied, ok := bindVariables[name]
|
||||
if !ok {
|
||||
return nil, false, fmt.Errorf("missing bind var %s", name)
|
||||
}
|
||||
|
||||
if isList {
|
||||
if supplied.Type != querypb.Type_TUPLE {
|
||||
return nil, false, fmt.Errorf("unexpected list arg type (%v) for key %s", supplied.Type, name)
|
||||
}
|
||||
if len(supplied.Values) == 0 {
|
||||
return nil, false, fmt.Errorf("empty list supplied for %s", name)
|
||||
}
|
||||
return supplied, true, nil
|
||||
}
|
||||
|
||||
if supplied.Type == querypb.Type_TUPLE {
|
||||
return nil, false, fmt.Errorf("unexpected arg type (TUPLE) for non-list key %s", name)
|
||||
}
|
||||
|
||||
return supplied, false, nil
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
package sqlparser
|
||||
|
||||
import querypb "github.com/xwb1989/sqlparser/dependency/querypb"
|
||||
|
||||
// RedactSQLQuery returns a sql string with the params stripped out for display
|
||||
func RedactSQLQuery(sql string) (string, error) {
|
||||
bv := map[string]*querypb.BindVariable{}
|
||||
sqlStripped, comments := SplitMarginComments(sql)
|
||||
|
||||
stmt, err := Parse(sqlStripped)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
prefix := "redacted"
|
||||
Normalize(stmt, bv, prefix)
|
||||
|
||||
return comments.Leading + String(stmt) + comments.Trailing, nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,950 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/xwb1989/sqlparser/dependency/bytes2"
|
||||
"github.com/xwb1989/sqlparser/dependency/sqltypes"
|
||||
)
|
||||
|
||||
const (
	// defaultBufSize is the size of the read buffer allocated by NewTokenizer.
	defaultBufSize = 4096
	// eofChar marks end of input. It is one past the largest byte value,
	// so it cannot collide with a character read from the input.
	eofChar = 0x100
)
|
||||
|
||||
// Tokenizer is the struct used to generate SQL
|
||||
// tokens for the parser.
|
||||
type Tokenizer struct {
|
||||
InStream io.Reader
|
||||
AllowComments bool
|
||||
ForceEOF bool
|
||||
lastChar uint16
|
||||
Position int
|
||||
lastToken []byte
|
||||
LastError error
|
||||
posVarIndex int
|
||||
ParseTree Statement
|
||||
partialDDL *DDL
|
||||
nesting int
|
||||
multi bool
|
||||
specialComment *Tokenizer
|
||||
|
||||
buf []byte
|
||||
bufPos int
|
||||
bufSize int
|
||||
}
|
||||
|
||||
// NewStringTokenizer creates a new Tokenizer for the
|
||||
// sql string.
|
||||
func NewStringTokenizer(sql string) *Tokenizer {
|
||||
buf := []byte(sql)
|
||||
return &Tokenizer{
|
||||
buf: buf,
|
||||
bufSize: len(buf),
|
||||
}
|
||||
}
|
||||
|
||||
// NewTokenizer creates a new Tokenizer reading a sql
|
||||
// string from the io.Reader.
|
||||
func NewTokenizer(r io.Reader) *Tokenizer {
|
||||
return &Tokenizer{
|
||||
InStream: r,
|
||||
buf: make([]byte, defaultBufSize),
|
||||
}
|
||||
}
|
||||
|
||||
// keywords is a map of mysql keywords that fall into two categories:
|
||||
// 1) keywords considered reserved by MySQL
|
||||
// 2) keywords for us to handle specially in sql.y
|
||||
//
|
||||
// Those marked as UNUSED are likely reserved keywords. We add them here so that
|
||||
// when rewriting queries we can properly backtick quote them so they don't cause issues
|
||||
//
|
||||
// NOTE: If you add new keywords, add them also to the reserved_keywords or
|
||||
// non_reserved_keywords grammar in sql.y -- this will allow the keyword to be used
|
||||
// in identifiers. See the docs for each grammar to determine which one to put it into.
|
||||
var keywords = map[string]int{
|
||||
"accessible": UNUSED,
|
||||
"add": ADD,
|
||||
"against": AGAINST,
|
||||
"all": ALL,
|
||||
"alter": ALTER,
|
||||
"analyze": ANALYZE,
|
||||
"and": AND,
|
||||
"as": AS,
|
||||
"asc": ASC,
|
||||
"asensitive": UNUSED,
|
||||
"auto_increment": AUTO_INCREMENT,
|
||||
"before": UNUSED,
|
||||
"begin": BEGIN,
|
||||
"between": BETWEEN,
|
||||
"bigint": BIGINT,
|
||||
"binary": BINARY,
|
||||
"_binary": UNDERSCORE_BINARY,
|
||||
"bit": BIT,
|
||||
"blob": BLOB,
|
||||
"bool": BOOL,
|
||||
"boolean": BOOLEAN,
|
||||
"both": UNUSED,
|
||||
"by": BY,
|
||||
"call": UNUSED,
|
||||
"cascade": UNUSED,
|
||||
"case": CASE,
|
||||
"cast": CAST,
|
||||
"change": UNUSED,
|
||||
"char": CHAR,
|
||||
"character": CHARACTER,
|
||||
"charset": CHARSET,
|
||||
"check": UNUSED,
|
||||
"collate": COLLATE,
|
||||
"column": COLUMN,
|
||||
"comment": COMMENT_KEYWORD,
|
||||
"committed": COMMITTED,
|
||||
"commit": COMMIT,
|
||||
"condition": UNUSED,
|
||||
"constraint": CONSTRAINT,
|
||||
"continue": UNUSED,
|
||||
"convert": CONVERT,
|
||||
"substr": SUBSTR,
|
||||
"substring": SUBSTRING,
|
||||
"create": CREATE,
|
||||
"cross": CROSS,
|
||||
"current_date": CURRENT_DATE,
|
||||
"current_time": CURRENT_TIME,
|
||||
"current_timestamp": CURRENT_TIMESTAMP,
|
||||
"current_user": UNUSED,
|
||||
"cursor": UNUSED,
|
||||
"database": DATABASE,
|
||||
"databases": DATABASES,
|
||||
"day_hour": UNUSED,
|
||||
"day_microsecond": UNUSED,
|
||||
"day_minute": UNUSED,
|
||||
"day_second": UNUSED,
|
||||
"date": DATE,
|
||||
"datetime": DATETIME,
|
||||
"dec": UNUSED,
|
||||
"decimal": DECIMAL,
|
||||
"declare": UNUSED,
|
||||
"default": DEFAULT,
|
||||
"delayed": UNUSED,
|
||||
"delete": DELETE,
|
||||
"desc": DESC,
|
||||
"describe": DESCRIBE,
|
||||
"deterministic": UNUSED,
|
||||
"distinct": DISTINCT,
|
||||
"distinctrow": UNUSED,
|
||||
"div": DIV,
|
||||
"double": DOUBLE,
|
||||
"drop": DROP,
|
||||
"duplicate": DUPLICATE,
|
||||
"each": UNUSED,
|
||||
"else": ELSE,
|
||||
"elseif": UNUSED,
|
||||
"enclosed": UNUSED,
|
||||
"end": END,
|
||||
"enum": ENUM,
|
||||
"escape": ESCAPE,
|
||||
"escaped": UNUSED,
|
||||
"exists": EXISTS,
|
||||
"exit": UNUSED,
|
||||
"explain": EXPLAIN,
|
||||
"expansion": EXPANSION,
|
||||
"extended": EXTENDED,
|
||||
"false": FALSE,
|
||||
"fetch": UNUSED,
|
||||
"float": FLOAT_TYPE,
|
||||
"float4": UNUSED,
|
||||
"float8": UNUSED,
|
||||
"for": FOR,
|
||||
"force": FORCE,
|
||||
"foreign": FOREIGN,
|
||||
"from": FROM,
|
||||
"full": FULL,
|
||||
"fulltext": FULLTEXT,
|
||||
"generated": UNUSED,
|
||||
"geometry": GEOMETRY,
|
||||
"geometrycollection": GEOMETRYCOLLECTION,
|
||||
"get": UNUSED,
|
||||
"global": GLOBAL,
|
||||
"grant": UNUSED,
|
||||
"group": GROUP,
|
||||
"group_concat": GROUP_CONCAT,
|
||||
"having": HAVING,
|
||||
"high_priority": UNUSED,
|
||||
"hour_microsecond": UNUSED,
|
||||
"hour_minute": UNUSED,
|
||||
"hour_second": UNUSED,
|
||||
"if": IF,
|
||||
"ignore": IGNORE,
|
||||
"in": IN,
|
||||
"index": INDEX,
|
||||
"infile": UNUSED,
|
||||
"inout": UNUSED,
|
||||
"inner": INNER,
|
||||
"insensitive": UNUSED,
|
||||
"insert": INSERT,
|
||||
"int": INT,
|
||||
"int1": UNUSED,
|
||||
"int2": UNUSED,
|
||||
"int3": UNUSED,
|
||||
"int4": UNUSED,
|
||||
"int8": UNUSED,
|
||||
"integer": INTEGER,
|
||||
"interval": INTERVAL,
|
||||
"into": INTO,
|
||||
"io_after_gtids": UNUSED,
|
||||
"is": IS,
|
||||
"isolation": ISOLATION,
|
||||
"iterate": UNUSED,
|
||||
"join": JOIN,
|
||||
"json": JSON,
|
||||
"key": KEY,
|
||||
"keys": KEYS,
|
||||
"key_block_size": KEY_BLOCK_SIZE,
|
||||
"kill": UNUSED,
|
||||
"language": LANGUAGE,
|
||||
"last_insert_id": LAST_INSERT_ID,
|
||||
"leading": UNUSED,
|
||||
"leave": UNUSED,
|
||||
"left": LEFT,
|
||||
"less": LESS,
|
||||
"level": LEVEL,
|
||||
"like": LIKE,
|
||||
"limit": LIMIT,
|
||||
"linear": UNUSED,
|
||||
"lines": UNUSED,
|
||||
"linestring": LINESTRING,
|
||||
"load": UNUSED,
|
||||
"localtime": LOCALTIME,
|
||||
"localtimestamp": LOCALTIMESTAMP,
|
||||
"lock": LOCK,
|
||||
"long": UNUSED,
|
||||
"longblob": LONGBLOB,
|
||||
"longtext": LONGTEXT,
|
||||
"loop": UNUSED,
|
||||
"low_priority": UNUSED,
|
||||
"master_bind": UNUSED,
|
||||
"match": MATCH,
|
||||
"maxvalue": MAXVALUE,
|
||||
"mediumblob": MEDIUMBLOB,
|
||||
"mediumint": MEDIUMINT,
|
||||
"mediumtext": MEDIUMTEXT,
|
||||
"middleint": UNUSED,
|
||||
"minute_microsecond": UNUSED,
|
||||
"minute_second": UNUSED,
|
||||
"mod": MOD,
|
||||
"mode": MODE,
|
||||
"modifies": UNUSED,
|
||||
"multilinestring": MULTILINESTRING,
|
||||
"multipoint": MULTIPOINT,
|
||||
"multipolygon": MULTIPOLYGON,
|
||||
"names": NAMES,
|
||||
"natural": NATURAL,
|
||||
"nchar": NCHAR,
|
||||
"next": NEXT,
|
||||
"not": NOT,
|
||||
"no_write_to_binlog": UNUSED,
|
||||
"null": NULL,
|
||||
"numeric": NUMERIC,
|
||||
"offset": OFFSET,
|
||||
"on": ON,
|
||||
"only": ONLY,
|
||||
"optimize": OPTIMIZE,
|
||||
"optimizer_costs": UNUSED,
|
||||
"option": UNUSED,
|
||||
"optionally": UNUSED,
|
||||
"or": OR,
|
||||
"order": ORDER,
|
||||
"out": UNUSED,
|
||||
"outer": OUTER,
|
||||
"outfile": UNUSED,
|
||||
"partition": PARTITION,
|
||||
"point": POINT,
|
||||
"polygon": POLYGON,
|
||||
"precision": UNUSED,
|
||||
"primary": PRIMARY,
|
||||
"processlist": PROCESSLIST,
|
||||
"procedure": PROCEDURE,
|
||||
"query": QUERY,
|
||||
"range": UNUSED,
|
||||
"read": READ,
|
||||
"reads": UNUSED,
|
||||
"read_write": UNUSED,
|
||||
"real": REAL,
|
||||
"references": UNUSED,
|
||||
"regexp": REGEXP,
|
||||
"release": UNUSED,
|
||||
"rename": RENAME,
|
||||
"reorganize": REORGANIZE,
|
||||
"repair": REPAIR,
|
||||
"repeat": UNUSED,
|
||||
"repeatable": REPEATABLE,
|
||||
"replace": REPLACE,
|
||||
"require": UNUSED,
|
||||
"resignal": UNUSED,
|
||||
"restrict": UNUSED,
|
||||
"return": UNUSED,
|
||||
"revoke": UNUSED,
|
||||
"right": RIGHT,
|
||||
"rlike": REGEXP,
|
||||
"rollback": ROLLBACK,
|
||||
"schema": SCHEMA,
|
||||
"schemas": UNUSED,
|
||||
"second_microsecond": UNUSED,
|
||||
"select": SELECT,
|
||||
"sensitive": UNUSED,
|
||||
"separator": SEPARATOR,
|
||||
"serializable": SERIALIZABLE,
|
||||
"session": SESSION,
|
||||
"set": SET,
|
||||
"share": SHARE,
|
||||
"show": SHOW,
|
||||
"signal": UNUSED,
|
||||
"signed": SIGNED,
|
||||
"smallint": SMALLINT,
|
||||
"spatial": SPATIAL,
|
||||
"specific": UNUSED,
|
||||
"sql": UNUSED,
|
||||
"sqlexception": UNUSED,
|
||||
"sqlstate": UNUSED,
|
||||
"sqlwarning": UNUSED,
|
||||
"sql_big_result": UNUSED,
|
||||
"sql_cache": SQL_CACHE,
|
||||
"sql_calc_found_rows": UNUSED,
|
||||
"sql_no_cache": SQL_NO_CACHE,
|
||||
"sql_small_result": UNUSED,
|
||||
"ssl": UNUSED,
|
||||
"start": START,
|
||||
"starting": UNUSED,
|
||||
"status": STATUS,
|
||||
"stored": UNUSED,
|
||||
"straight_join": STRAIGHT_JOIN,
|
||||
"stream": STREAM,
|
||||
"table": TABLE,
|
||||
"tables": TABLES,
|
||||
"terminated": UNUSED,
|
||||
"text": TEXT,
|
||||
"than": THAN,
|
||||
"then": THEN,
|
||||
"time": TIME,
|
||||
"timestamp": TIMESTAMP,
|
||||
"tinyblob": TINYBLOB,
|
||||
"tinyint": TINYINT,
|
||||
"tinytext": TINYTEXT,
|
||||
"to": TO,
|
||||
"trailing": UNUSED,
|
||||
"transaction": TRANSACTION,
|
||||
"trigger": TRIGGER,
|
||||
"true": TRUE,
|
||||
"truncate": TRUNCATE,
|
||||
"uncommitted": UNCOMMITTED,
|
||||
"undo": UNUSED,
|
||||
"union": UNION,
|
||||
"unique": UNIQUE,
|
||||
"unlock": UNUSED,
|
||||
"unsigned": UNSIGNED,
|
||||
"update": UPDATE,
|
||||
"usage": UNUSED,
|
||||
"use": USE,
|
||||
"using": USING,
|
||||
"utc_date": UTC_DATE,
|
||||
"utc_time": UTC_TIME,
|
||||
"utc_timestamp": UTC_TIMESTAMP,
|
||||
"values": VALUES,
|
||||
"variables": VARIABLES,
|
||||
"varbinary": VARBINARY,
|
||||
"varchar": VARCHAR,
|
||||
"varcharacter": UNUSED,
|
||||
"varying": UNUSED,
|
||||
"virtual": UNUSED,
|
||||
"vindex": VINDEX,
|
||||
"vindexes": VINDEXES,
|
||||
"view": VIEW,
|
||||
"vitess_keyspaces": VITESS_KEYSPACES,
|
||||
"vitess_shards": VITESS_SHARDS,
|
||||
"vitess_tablets": VITESS_TABLETS,
|
||||
"vschema_tables": VSCHEMA_TABLES,
|
||||
"when": WHEN,
|
||||
"where": WHERE,
|
||||
"while": UNUSED,
|
||||
"with": WITH,
|
||||
"write": WRITE,
|
||||
"xor": UNUSED,
|
||||
"year": YEAR,
|
||||
"year_month": UNUSED,
|
||||
"zerofill": ZEROFILL,
|
||||
}
|
||||
|
||||
// keywordStrings contains the reverse mapping of token to keyword strings
|
||||
var keywordStrings = map[int]string{}
|
||||
|
||||
func init() {
|
||||
for str, id := range keywords {
|
||||
if id == UNUSED {
|
||||
continue
|
||||
}
|
||||
keywordStrings[id] = str
|
||||
}
|
||||
}
|
||||
|
||||
// KeywordString returns the string corresponding to the given keyword
|
||||
func KeywordString(id int) string {
|
||||
str, ok := keywordStrings[id]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// Lex returns the next token form the Tokenizer.
|
||||
// This function is used by go yacc.
|
||||
func (tkn *Tokenizer) Lex(lval *yySymType) int {
|
||||
typ, val := tkn.Scan()
|
||||
for typ == COMMENT {
|
||||
if tkn.AllowComments {
|
||||
break
|
||||
}
|
||||
typ, val = tkn.Scan()
|
||||
}
|
||||
lval.bytes = val
|
||||
tkn.lastToken = val
|
||||
return typ
|
||||
}
|
||||
|
||||
// Error is called by go yacc if there's a parsing error.
|
||||
func (tkn *Tokenizer) Error(err string) {
|
||||
buf := &bytes2.Buffer{}
|
||||
if tkn.lastToken != nil {
|
||||
fmt.Fprintf(buf, "%s at position %v near '%s'", err, tkn.Position, tkn.lastToken)
|
||||
} else {
|
||||
fmt.Fprintf(buf, "%s at position %v", err, tkn.Position)
|
||||
}
|
||||
tkn.LastError = errors.New(buf.String())
|
||||
|
||||
// Try and re-sync to the next statement
|
||||
if tkn.lastChar != ';' {
|
||||
tkn.skipStatement()
|
||||
}
|
||||
}
|
||||
|
||||
// Scan scans the tokenizer for the next token and returns
|
||||
// the token type and an optional value.
|
||||
func (tkn *Tokenizer) Scan() (int, []byte) {
|
||||
if tkn.specialComment != nil {
|
||||
// Enter specialComment scan mode.
|
||||
// for scanning such kind of comment: /*! MySQL-specific code */
|
||||
specialComment := tkn.specialComment
|
||||
tok, val := specialComment.Scan()
|
||||
if tok != 0 {
|
||||
// return the specialComment scan result as the result
|
||||
return tok, val
|
||||
}
|
||||
// leave specialComment scan mode after all stream consumed.
|
||||
tkn.specialComment = nil
|
||||
}
|
||||
if tkn.lastChar == 0 {
|
||||
tkn.next()
|
||||
}
|
||||
|
||||
if tkn.ForceEOF {
|
||||
tkn.skipStatement()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
tkn.skipBlank()
|
||||
switch ch := tkn.lastChar; {
|
||||
case isLetter(ch):
|
||||
tkn.next()
|
||||
if ch == 'X' || ch == 'x' {
|
||||
if tkn.lastChar == '\'' {
|
||||
tkn.next()
|
||||
return tkn.scanHex()
|
||||
}
|
||||
}
|
||||
if ch == 'B' || ch == 'b' {
|
||||
if tkn.lastChar == '\'' {
|
||||
tkn.next()
|
||||
return tkn.scanBitLiteral()
|
||||
}
|
||||
}
|
||||
isDbSystemVariable := false
|
||||
if ch == '@' && tkn.lastChar == '@' {
|
||||
isDbSystemVariable = true
|
||||
}
|
||||
return tkn.scanIdentifier(byte(ch), isDbSystemVariable)
|
||||
case isDigit(ch):
|
||||
return tkn.scanNumber(false)
|
||||
case ch == ':':
|
||||
return tkn.scanBindVar()
|
||||
case ch == ';' && tkn.multi:
|
||||
return 0, nil
|
||||
default:
|
||||
tkn.next()
|
||||
switch ch {
|
||||
case eofChar:
|
||||
return 0, nil
|
||||
case '=', ',', ';', '(', ')', '+', '*', '%', '^', '~':
|
||||
return int(ch), nil
|
||||
case '&':
|
||||
if tkn.lastChar == '&' {
|
||||
tkn.next()
|
||||
return AND, nil
|
||||
}
|
||||
return int(ch), nil
|
||||
case '|':
|
||||
if tkn.lastChar == '|' {
|
||||
tkn.next()
|
||||
return OR, nil
|
||||
}
|
||||
return int(ch), nil
|
||||
case '?':
|
||||
tkn.posVarIndex++
|
||||
buf := new(bytes2.Buffer)
|
||||
fmt.Fprintf(buf, ":v%d", tkn.posVarIndex)
|
||||
return VALUE_ARG, buf.Bytes()
|
||||
case '.':
|
||||
if isDigit(tkn.lastChar) {
|
||||
return tkn.scanNumber(true)
|
||||
}
|
||||
return int(ch), nil
|
||||
case '/':
|
||||
switch tkn.lastChar {
|
||||
case '/':
|
||||
tkn.next()
|
||||
return tkn.scanCommentType1("//")
|
||||
case '*':
|
||||
tkn.next()
|
||||
switch tkn.lastChar {
|
||||
case '!':
|
||||
return tkn.scanMySQLSpecificComment()
|
||||
default:
|
||||
return tkn.scanCommentType2()
|
||||
}
|
||||
default:
|
||||
return int(ch), nil
|
||||
}
|
||||
case '#':
|
||||
return tkn.scanCommentType1("#")
|
||||
case '-':
|
||||
switch tkn.lastChar {
|
||||
case '-':
|
||||
tkn.next()
|
||||
return tkn.scanCommentType1("--")
|
||||
case '>':
|
||||
tkn.next()
|
||||
if tkn.lastChar == '>' {
|
||||
tkn.next()
|
||||
return JSON_UNQUOTE_EXTRACT_OP, nil
|
||||
}
|
||||
return JSON_EXTRACT_OP, nil
|
||||
}
|
||||
return int(ch), nil
|
||||
case '<':
|
||||
switch tkn.lastChar {
|
||||
case '>':
|
||||
tkn.next()
|
||||
return NE, nil
|
||||
case '<':
|
||||
tkn.next()
|
||||
return SHIFT_LEFT, nil
|
||||
case '=':
|
||||
tkn.next()
|
||||
switch tkn.lastChar {
|
||||
case '>':
|
||||
tkn.next()
|
||||
return NULL_SAFE_EQUAL, nil
|
||||
default:
|
||||
return LE, nil
|
||||
}
|
||||
default:
|
||||
return int(ch), nil
|
||||
}
|
||||
case '>':
|
||||
switch tkn.lastChar {
|
||||
case '=':
|
||||
tkn.next()
|
||||
return GE, nil
|
||||
case '>':
|
||||
tkn.next()
|
||||
return SHIFT_RIGHT, nil
|
||||
default:
|
||||
return int(ch), nil
|
||||
}
|
||||
case '!':
|
||||
if tkn.lastChar == '=' {
|
||||
tkn.next()
|
||||
return NE, nil
|
||||
}
|
||||
return int(ch), nil
|
||||
case '\'', '"':
|
||||
return tkn.scanString(ch, STRING)
|
||||
case '`':
|
||||
return tkn.scanLiteralIdentifier()
|
||||
default:
|
||||
return LEX_ERROR, []byte{byte(ch)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// skipStatement scans until the EOF, or end of statement is encountered.
|
||||
func (tkn *Tokenizer) skipStatement() {
|
||||
ch := tkn.lastChar
|
||||
for ch != ';' && ch != eofChar {
|
||||
tkn.next()
|
||||
ch = tkn.lastChar
|
||||
}
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) skipBlank() {
|
||||
ch := tkn.lastChar
|
||||
for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' {
|
||||
tkn.next()
|
||||
ch = tkn.lastChar
|
||||
}
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanIdentifier(firstByte byte, isDbSystemVariable bool) (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
buffer.WriteByte(firstByte)
|
||||
for isLetter(tkn.lastChar) || isDigit(tkn.lastChar) || (isDbSystemVariable && isCarat(tkn.lastChar)) {
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
tkn.next()
|
||||
}
|
||||
lowered := bytes.ToLower(buffer.Bytes())
|
||||
loweredStr := string(lowered)
|
||||
if keywordID, found := keywords[loweredStr]; found {
|
||||
return keywordID, lowered
|
||||
}
|
||||
// dual must always be case-insensitive
|
||||
if loweredStr == "dual" {
|
||||
return ID, lowered
|
||||
}
|
||||
return ID, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanHex() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
tkn.scanMantissa(16, buffer)
|
||||
if tkn.lastChar != '\'' {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
tkn.next()
|
||||
if buffer.Len()%2 != 0 {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
return HEX, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanBitLiteral() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
tkn.scanMantissa(2, buffer)
|
||||
if tkn.lastChar != '\'' {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
tkn.next()
|
||||
return BIT_LITERAL, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanLiteralIdentifier() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
backTickSeen := false
|
||||
for {
|
||||
if backTickSeen {
|
||||
if tkn.lastChar != '`' {
|
||||
break
|
||||
}
|
||||
backTickSeen = false
|
||||
buffer.WriteByte('`')
|
||||
tkn.next()
|
||||
continue
|
||||
}
|
||||
// The previous char was not a backtick.
|
||||
switch tkn.lastChar {
|
||||
case '`':
|
||||
backTickSeen = true
|
||||
case eofChar:
|
||||
// Premature EOF.
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
default:
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
}
|
||||
tkn.next()
|
||||
}
|
||||
if buffer.Len() == 0 {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
return ID, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanBindVar() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
token := VALUE_ARG
|
||||
tkn.next()
|
||||
if tkn.lastChar == ':' {
|
||||
token = LIST_ARG
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
tkn.next()
|
||||
}
|
||||
if !isLetter(tkn.lastChar) {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
for isLetter(tkn.lastChar) || isDigit(tkn.lastChar) || tkn.lastChar == '.' {
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
tkn.next()
|
||||
}
|
||||
return token, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanMantissa(base int, buffer *bytes2.Buffer) {
|
||||
for digitVal(tkn.lastChar) < base {
|
||||
tkn.consumeNext(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanNumber(seenDecimalPoint bool) (int, []byte) {
|
||||
token := INTEGRAL
|
||||
buffer := &bytes2.Buffer{}
|
||||
if seenDecimalPoint {
|
||||
token = FLOAT
|
||||
buffer.WriteByte('.')
|
||||
tkn.scanMantissa(10, buffer)
|
||||
goto exponent
|
||||
}
|
||||
|
||||
// 0x construct.
|
||||
if tkn.lastChar == '0' {
|
||||
tkn.consumeNext(buffer)
|
||||
if tkn.lastChar == 'x' || tkn.lastChar == 'X' {
|
||||
token = HEXNUM
|
||||
tkn.consumeNext(buffer)
|
||||
tkn.scanMantissa(16, buffer)
|
||||
goto exit
|
||||
}
|
||||
}
|
||||
|
||||
tkn.scanMantissa(10, buffer)
|
||||
|
||||
if tkn.lastChar == '.' {
|
||||
token = FLOAT
|
||||
tkn.consumeNext(buffer)
|
||||
tkn.scanMantissa(10, buffer)
|
||||
}
|
||||
|
||||
exponent:
|
||||
if tkn.lastChar == 'e' || tkn.lastChar == 'E' {
|
||||
token = FLOAT
|
||||
tkn.consumeNext(buffer)
|
||||
if tkn.lastChar == '+' || tkn.lastChar == '-' {
|
||||
tkn.consumeNext(buffer)
|
||||
}
|
||||
tkn.scanMantissa(10, buffer)
|
||||
}
|
||||
|
||||
exit:
|
||||
// A letter cannot immediately follow a number.
|
||||
if isLetter(tkn.lastChar) {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
|
||||
return token, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, []byte) {
|
||||
var buffer bytes2.Buffer
|
||||
for {
|
||||
ch := tkn.lastChar
|
||||
if ch == eofChar {
|
||||
// Unterminated string.
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
|
||||
if ch != delim && ch != '\\' {
|
||||
buffer.WriteByte(byte(ch))
|
||||
|
||||
// Scan ahead to the next interesting character.
|
||||
start := tkn.bufPos
|
||||
for ; tkn.bufPos < tkn.bufSize; tkn.bufPos++ {
|
||||
ch = uint16(tkn.buf[tkn.bufPos])
|
||||
if ch == delim || ch == '\\' {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
buffer.Write(tkn.buf[start:tkn.bufPos])
|
||||
tkn.Position += (tkn.bufPos - start)
|
||||
|
||||
if tkn.bufPos >= tkn.bufSize {
|
||||
// Reached the end of the buffer without finding a delim or
|
||||
// escape character.
|
||||
tkn.next()
|
||||
continue
|
||||
}
|
||||
|
||||
tkn.bufPos++
|
||||
tkn.Position++
|
||||
}
|
||||
tkn.next() // Read one past the delim or escape character.
|
||||
|
||||
if ch == '\\' {
|
||||
if tkn.lastChar == eofChar {
|
||||
// String terminates mid escape character.
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
if decodedChar := sqltypes.SQLDecodeMap[byte(tkn.lastChar)]; decodedChar == sqltypes.DontEscape {
|
||||
ch = tkn.lastChar
|
||||
} else {
|
||||
ch = uint16(decodedChar)
|
||||
}
|
||||
|
||||
} else if ch == delim && tkn.lastChar != delim {
|
||||
// Correctly terminated string, which is not a double delim.
|
||||
break
|
||||
}
|
||||
|
||||
buffer.WriteByte(byte(ch))
|
||||
tkn.next()
|
||||
}
|
||||
|
||||
return typ, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanCommentType1(prefix string) (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
buffer.WriteString(prefix)
|
||||
for tkn.lastChar != eofChar {
|
||||
if tkn.lastChar == '\n' {
|
||||
tkn.consumeNext(buffer)
|
||||
break
|
||||
}
|
||||
tkn.consumeNext(buffer)
|
||||
}
|
||||
return COMMENT, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanCommentType2() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
buffer.WriteString("/*")
|
||||
for {
|
||||
if tkn.lastChar == '*' {
|
||||
tkn.consumeNext(buffer)
|
||||
if tkn.lastChar == '/' {
|
||||
tkn.consumeNext(buffer)
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tkn.lastChar == eofChar {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
tkn.consumeNext(buffer)
|
||||
}
|
||||
return COMMENT, buffer.Bytes()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanMySQLSpecificComment() (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
buffer.WriteString("/*!")
|
||||
tkn.next()
|
||||
for {
|
||||
if tkn.lastChar == '*' {
|
||||
tkn.consumeNext(buffer)
|
||||
if tkn.lastChar == '/' {
|
||||
tkn.consumeNext(buffer)
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tkn.lastChar == eofChar {
|
||||
return LEX_ERROR, buffer.Bytes()
|
||||
}
|
||||
tkn.consumeNext(buffer)
|
||||
}
|
||||
_, sql := ExtractMysqlComment(buffer.String())
|
||||
tkn.specialComment = NewStringTokenizer(sql)
|
||||
return tkn.Scan()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) consumeNext(buffer *bytes2.Buffer) {
|
||||
if tkn.lastChar == eofChar {
|
||||
// This should never happen.
|
||||
panic("unexpected EOF")
|
||||
}
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
tkn.next()
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) next() {
|
||||
if tkn.bufPos >= tkn.bufSize && tkn.InStream != nil {
|
||||
// Try and refill the buffer
|
||||
var err error
|
||||
tkn.bufPos = 0
|
||||
if tkn.bufSize, err = tkn.InStream.Read(tkn.buf); err != io.EOF && err != nil {
|
||||
tkn.LastError = err
|
||||
}
|
||||
}
|
||||
|
||||
if tkn.bufPos >= tkn.bufSize {
|
||||
if tkn.lastChar != eofChar {
|
||||
tkn.Position++
|
||||
tkn.lastChar = eofChar
|
||||
}
|
||||
} else {
|
||||
tkn.Position++
|
||||
tkn.lastChar = uint16(tkn.buf[tkn.bufPos])
|
||||
tkn.bufPos++
|
||||
}
|
||||
}
|
||||
|
||||
// reset clears any internal state.
|
||||
func (tkn *Tokenizer) reset() {
|
||||
tkn.ParseTree = nil
|
||||
tkn.partialDDL = nil
|
||||
tkn.specialComment = nil
|
||||
tkn.posVarIndex = 0
|
||||
tkn.nesting = 0
|
||||
tkn.ForceEOF = false
|
||||
}
|
||||
|
||||
// isLetter reports whether ch is an ASCII letter, '_' or '@'.
func isLetter(ch uint16) bool {
	switch {
	case 'a' <= ch && ch <= 'z', 'A' <= ch && ch <= 'Z':
		return true
	}
	return ch == '_' || ch == '@'
}
|
||||
|
||||
// isCarat reports whether ch is one of '.', single quote, double quote or
// backtick.
func isCarat(ch uint16) bool {
	switch ch {
	case '.', '\'', '"', '`':
		return true
	}
	return false
}
|
||||
|
||||
// digitVal returns the numeric value of ch read as a digit: 0-9 for '0'-'9'
// and 10-15 for 'a'-'f' or 'A'-'F'. Any other character yields 16, which is
// larger than any legal digit value.
func digitVal(ch uint16) int {
	if '0' <= ch && ch <= '9' {
		return int(ch - '0')
	}
	if 'a' <= ch && ch <= 'f' {
		return 10 + int(ch-'a')
	}
	if 'A' <= ch && ch <= 'F' {
		return 10 + int(ch-'A')
	}
	return 16
}
|
||||
|
||||
// isDigit reports whether ch is an ASCII decimal digit.
func isDigit(ch uint16) bool {
	return !(ch < '0' || ch > '9')
}
|
|
@ -0,0 +1,140 @@
|
|||
/*
|
||||
Copyright 2017 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// NodeFormatter defines the signature of a custom node formatter
|
||||
// function that can be given to TrackedBuffer for code generation.
|
||||
type NodeFormatter func(buf *TrackedBuffer, node SQLNode)
|
||||
|
||||
// TrackedBuffer is used to rebuild a query from the ast.
|
||||
// bindLocations keeps track of locations in the buffer that
|
||||
// use bind variables for efficient future substitutions.
|
||||
// nodeFormatter is the formatting function the buffer will
|
||||
// use to format a node. By default(nil), it's FormatNode.
|
||||
// But you can supply a different formatting function if you
|
||||
// want to generate a query that's different from the default.
|
||||
type TrackedBuffer struct {
|
||||
*bytes.Buffer
|
||||
bindLocations []bindLocation
|
||||
nodeFormatter NodeFormatter
|
||||
}
|
||||
|
||||
// NewTrackedBuffer creates a new TrackedBuffer.
|
||||
func NewTrackedBuffer(nodeFormatter NodeFormatter) *TrackedBuffer {
|
||||
return &TrackedBuffer{
|
||||
Buffer: new(bytes.Buffer),
|
||||
nodeFormatter: nodeFormatter,
|
||||
}
|
||||
}
|
||||
|
||||
// WriteNode function, initiates the writing of a single SQLNode tree by passing
|
||||
// through to Myprintf with a default format string
|
||||
func (buf *TrackedBuffer) WriteNode(node SQLNode) *TrackedBuffer {
|
||||
buf.Myprintf("%v", node)
|
||||
return buf
|
||||
}
|
||||
|
||||
// Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v),
|
||||
// Node.Value(%s) and string(%s). It also allows a %a for a value argument, in
|
||||
// which case it adds tracking info for future substitutions.
|
||||
//
|
||||
// The name must be something other than the usual Printf() to avoid "go vet"
|
||||
// warnings due to our custom format specifiers.
|
||||
func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
|
||||
end := len(format)
|
||||
fieldnum := 0
|
||||
for i := 0; i < end; {
|
||||
lasti := i
|
||||
for i < end && format[i] != '%' {
|
||||
i++
|
||||
}
|
||||
if i > lasti {
|
||||
buf.WriteString(format[lasti:i])
|
||||
}
|
||||
if i >= end {
|
||||
break
|
||||
}
|
||||
i++ // '%'
|
||||
switch format[i] {
|
||||
case 'c':
|
||||
switch v := values[fieldnum].(type) {
|
||||
case byte:
|
||||
buf.WriteByte(v)
|
||||
case rune:
|
||||
buf.WriteRune(v)
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
|
||||
}
|
||||
case 's':
|
||||
switch v := values[fieldnum].(type) {
|
||||
case []byte:
|
||||
buf.Write(v)
|
||||
case string:
|
||||
buf.WriteString(v)
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
|
||||
}
|
||||
case 'v':
|
||||
node := values[fieldnum].(SQLNode)
|
||||
if buf.nodeFormatter == nil {
|
||||
node.Format(buf)
|
||||
} else {
|
||||
buf.nodeFormatter(buf, node)
|
||||
}
|
||||
case 'a':
|
||||
buf.WriteArg(values[fieldnum].(string))
|
||||
default:
|
||||
panic("unexpected")
|
||||
}
|
||||
fieldnum++
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// WriteArg writes a value argument into the buffer along with
|
||||
// tracking information for future substitutions. arg must contain
|
||||
// the ":" or "::" prefix.
|
||||
func (buf *TrackedBuffer) WriteArg(arg string) {
|
||||
buf.bindLocations = append(buf.bindLocations, bindLocation{
|
||||
offset: buf.Len(),
|
||||
length: len(arg),
|
||||
})
|
||||
buf.WriteString(arg)
|
||||
}
|
||||
|
||||
// ParsedQuery returns a ParsedQuery that contains bind
|
||||
// locations for easy substitution.
|
||||
func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
|
||||
return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations}
|
||||
}
|
||||
|
||||
// HasBindVars returns true if the parsed query uses bind vars.
|
||||
func (buf *TrackedBuffer) HasBindVars() bool {
|
||||
return len(buf.bindLocations) != 0
|
||||
}
|
||||
|
||||
// BuildParsedQuery builds a ParsedQuery from the input.
|
||||
func BuildParsedQuery(in string, vars ...interface{}) *ParsedQuery {
|
||||
buf := NewTrackedBuffer(nil)
|
||||
buf.Myprintf(in, vars...)
|
||||
return buf.ParsedQuery()
|
||||
}
|
|
@ -460,6 +460,12 @@
|
|||
"revision": "7e6a47b300b10be9449610a6ff4fbae17d6e95b6",
|
||||
"revisionTime": "2018-01-16T16:19:11Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "mA7isU/nIAT5ytwIzK65H0vlVqI=",
|
||||
"path": "github.com/klauspost/compress/flate",
|
||||
"revision": "5fb1f31b0a61e9858f12f39266e059848a5f1cea",
|
||||
"revisionTime": "2018-04-02T19:26:10Z"
|
||||
},
|
||||
{
|
||||
"path": "github.com/klauspost/cpuid",
|
||||
"revision": "349c675778172472f5e8f3a3e0fe187e302e5a10",
|
||||
|
@ -471,6 +477,12 @@
|
|||
"revision": "cb6bfca970f6908083f26f39a79009d608efd5cd",
|
||||
"revisionTime": "2016-10-16T15:41:25Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "o/9oGuVccfxPYuRdzvW0lv7Zbzg=",
|
||||
"path": "github.com/klauspost/pgzip",
|
||||
"revision": "90b2c57fba35a1dd05cb40f9200722763808d99b",
|
||||
"revisionTime": "2018-06-06T15:09:39Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "ehsrWipiGIWqa4To8TmelIx06vI=",
|
||||
"path": "github.com/klauspost/reedsolomon",
|
||||
|
@ -790,6 +802,42 @@
|
|||
"revision": "173748da739a410c5b0b813b956f89ff94730b4c",
|
||||
"revisionTime": "2016-08-30T17:39:30Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "MWqyOvDMkW+XYe2RJ5mplvut+aE=",
|
||||
"path": "github.com/ugorji/go/codec",
|
||||
"revision": "ded73eae5db7e7a0ef6f55aace87a2873c5d2b74",
|
||||
"revisionTime": "2017-01-07T13:32:03Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "6ksZHYhLc3yOzTbcWKb3bDENhD4=",
|
||||
"path": "github.com/xwb1989/sqlparser",
|
||||
"revision": "120387863bf27d04bc07db8015110a6e96d0146c",
|
||||
"revisionTime": "2018-06-06T15:21:19Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "L/Q8Ylbo+wnj5whDFfMxxwyxmdo=",
|
||||
"path": "github.com/xwb1989/sqlparser/dependency/bytes2",
|
||||
"revision": "120387863bf27d04bc07db8015110a6e96d0146c",
|
||||
"revisionTime": "2018-06-06T15:21:19Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "f9K0yQdwD0Z2yc3bmDw2uqXt4hU=",
|
||||
"path": "github.com/xwb1989/sqlparser/dependency/hack",
|
||||
"revision": "120387863bf27d04bc07db8015110a6e96d0146c",
|
||||
"revisionTime": "2018-06-06T15:21:19Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "xpu1JU/VZ7gGNbU5Ol9Nm1oS4tY=",
|
||||
"path": "github.com/xwb1989/sqlparser/dependency/querypb",
|
||||
"revision": "120387863bf27d04bc07db8015110a6e96d0146c",
|
||||
"revisionTime": "2018-06-06T15:21:19Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "KbNIySCQgMG81TRMJp1IDRfSgv8=",
|
||||
"path": "github.com/xwb1989/sqlparser/dependency/sqltypes",
|
||||
"revision": "120387863bf27d04bc07db8015110a6e96d0146c",
|
||||
"revisionTime": "2018-06-06T15:21:19Z"
|
||||
},
|
||||
{
|
||||
"checksumSHA1": "6NS7FWJl1FobB+Xfe4SzBGD+75g=",
|
||||
"path": "go.uber.org/atomic",
|
||||
|
|
Loading…
Reference in New Issue