Refactor streaming signatureV4 w/ state machine (#2862)

* Refactor streaming signatureV4 w/ state machine

- Used state machine to make transitions between reading chunk header,
  chunk data and trailer explicit.

* debug: add print/panic statements to gather more info on CI failure

* Persist lastChunk status between Read() on ChunkReader

... remove panic() which was added as interim aid for debugging.

* Add unit-tests to cover v4 streaming signature
This commit is contained in:
Krishnan Parthasarathi
2016-10-10 14:12:32 +05:30
committed by Harshavardhana
parent 3cfb23750a
commit 2d5e988a6d
5 changed files with 253 additions and 80 deletions

View File

@@ -175,6 +175,7 @@ func newSignV4ChunkedReader(req *http.Request) (io.Reader, APIErrorCode) {
seedSignature: seedSignature,
seedDate: seedDate,
chunkSHA256Writer: sha256.New(),
state: readChunkHeader,
}, ErrNone
}
@@ -184,7 +185,8 @@ type s3ChunkedReader struct {
reader *bufio.Reader
seedSignature string
seedDate time.Time
dataChunkRead bool
state chunkState
lastChunk bool
chunkSignature string
chunkSHA256Writer hash.Hash // Calculates sha256 of chunk data.
n uint64 // Unread bytes in chunk
@@ -207,99 +209,125 @@ func (cr *s3ChunkedReader) readS3ChunkHeader() {
if cr.n == 0 {
cr.err = io.EOF
}
// is the data part already read?, set this to false.
cr.dataChunkRead = false
// Reset sha256 hasher for a fresh start.
cr.chunkSHA256Writer.Reset()
// Save the incoming chunk signature.
cr.chunkSignature = string(hexChunkSignature)
}
// Validate if the underlying buffer has chunk header.
func (cr *s3ChunkedReader) s3ChunkHeaderAvailable() bool {
n := cr.reader.Buffered()
if n > 0 {
// Peek without seeking to look for trailing '\n'.
peek, _ := cr.reader.Peek(n)
return bytes.IndexByte(peek, '\n') >= 0
type chunkState int
const (
readChunkHeader chunkState = iota
readChunkTrailer
readChunk
verifyChunk
)
func (cs chunkState) String() string {
stateString := ""
switch cs {
case readChunkHeader:
stateString = "readChunkHeader"
case readChunkTrailer:
stateString = "readChunkTrailer"
case readChunk:
stateString = "readChunk"
case verifyChunk:
stateString = "verifyChunk"
}
return false
return stateString
}
// Read - implements `io.Reader`, which transparently decodes
// the incoming AWS Signature V4 streaming signature.
func (cr *s3ChunkedReader) Read(buf []byte) (n int, err error) {
for cr.err == nil {
if cr.n == 0 {
// For no chunk header available, we don't have to
// proceed to read again.
if n > 0 && !cr.s3ChunkHeaderAvailable() {
// We've read enough. Don't potentially block
// reading a new chunk header.
break
}
// If the chunk has been read, proceed to validate the rolling signature.
if cr.dataChunkRead {
// Calculate the hashed chunk.
hashedChunk := hex.EncodeToString(cr.chunkSHA256Writer.Sum(nil))
// Calculate the chunk signature.
newSignature := getChunkSignature(cr.seedSignature, cr.seedDate, hashedChunk)
if cr.chunkSignature != newSignature {
// Chunk signature doesn't match we return signature does not match.
cr.err = errSignatureMismatch
break
}
// Newly calculated signature becomes the seed for the next chunk
// this follows the chaining.
cr.seedSignature = newSignature
}
// Proceed to read the next chunk header.
for {
switch cr.state {
case readChunkHeader:
cr.readS3ChunkHeader()
continue
}
// With requested buffer of zero length, no need to read further.
if len(buf) == 0 {
break
}
rbuf := buf
// Make sure to read only the specified payload size, stagger
// the rest for subsequent requests.
if uint64(len(rbuf)) > cr.n {
rbuf = rbuf[:cr.n]
}
var n0 int
n0, cr.err = cr.reader.Read(rbuf)
// If we're at the end of a chunk.
if cr.n == 0 && cr.err == io.EOF {
cr.state = readChunkTrailer
cr.lastChunk = true
continue
}
if cr.err != nil {
return 0, cr.err
}
cr.state = readChunk
case readChunkTrailer:
cr.err = readCRLF(cr.reader)
if cr.err != nil {
return 0, errMalformedEncoding
}
cr.state = verifyChunk
case readChunk:
// There is no more space left in the request buffer.
if len(buf) == 0 {
return n, nil
}
rbuf := buf
// The request buffer is larger than the current chunk size.
// Read only the current chunk from the underlying reader.
if uint64(len(rbuf)) > cr.n {
rbuf = rbuf[:cr.n]
}
var n0 int
n0, cr.err = cr.reader.Read(rbuf)
if cr.err != nil {
// We have lesser than chunk size advertised in chunkHeader, this is 'unexpected'.
if cr.err == io.EOF {
cr.err = io.ErrUnexpectedEOF
}
return 0, cr.err
}
// Calculate sha256.
cr.chunkSHA256Writer.Write(rbuf[:n0])
// Set since we have read the chunk read.
cr.dataChunkRead = true
// Calculate sha256.
cr.chunkSHA256Writer.Write(rbuf[:n0])
// Update the bytes read into request buffer so far.
n += n0
buf = buf[n0:]
// Update bytes to be read of the current chunk before verifying chunk's signature.
cr.n -= uint64(n0)
n += n0
buf = buf[n0:]
// Decrements the 'cr.n' for future reads.
cr.n -= uint64(n0)
// If we're at the end of a chunk.
if cr.n == 0 && cr.err == nil {
// Read the next two bytes to verify if they are "\r\n".
cr.err = checkCRLF(cr.reader)
// If we're at the end of a chunk.
if cr.n == 0 {
cr.state = readChunkTrailer
continue
}
case verifyChunk:
// Calculate the hashed chunk.
hashedChunk := hex.EncodeToString(cr.chunkSHA256Writer.Sum(nil))
// Calculate the chunk signature.
newSignature := getChunkSignature(cr.seedSignature, cr.seedDate, hashedChunk)
if cr.chunkSignature != newSignature {
// Chunk signature doesn't match we return signature does not match.
cr.err = errSignatureMismatch
return 0, cr.err
}
// Newly calculated signature becomes the seed for the next chunk
// this follows the chaining.
cr.seedSignature = newSignature
cr.chunkSHA256Writer.Reset()
cr.state = readChunkHeader
if cr.lastChunk {
return n, nil
}
}
}
// Return number of bytes read, and error if any.
return n, cr.err
}
// checkCRLF - check if reader only has '\r\n' CRLF character.
// readCRLF - check if reader only has '\r\n' CRLF character.
// returns malformed encoding if it doesn't.
func checkCRLF(reader io.Reader) (err error) {
var buf = make([]byte, 2)
if _, err = io.ReadFull(reader, buf[:2]); err == nil {
if buf[0] != '\r' || buf[1] != '\n' {
err = errMalformedEncoding
}
func readCRLF(reader io.Reader) error {
buf := make([]byte, 2)
_, err := io.ReadFull(reader, buf[:2])
if err != nil {
return err
}
return err
if buf[0] != '\r' || buf[1] != '\n' {
return errMalformedEncoding
}
return nil
}
// Read a line of bytes (up to \n) from b.