2018-11-07 10:23:13 -08:00

735 lines
18 KiB
Go

package nsq
import (
"bufio"
"bytes"
"compress/flate"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/golang/snappy"
)
// IdentifyResponse represents the metadata
// returned from an IDENTIFY command to nsqd.
//
// It is decoded from the JSON body of the IDENTIFY response.
type IdentifyResponse struct {
// MaxRdyCount is stored on the connection and reported by Conn.MaxRDY.
MaxRdyCount int64 `json:"max_rdy_count"`
// TLSv1, when true, causes the connection to be upgraded to TLS.
TLSv1 bool `json:"tls_v1"`
// Deflate, when true, causes the connection to be wrapped in deflate compression.
Deflate bool `json:"deflate"`
// Snappy, when true, causes the connection to be wrapped in snappy compression.
Snappy bool `json:"snappy"`
// AuthRequired, when true, makes Connect send an AUTH command before returning.
AuthRequired bool `json:"auth_required"`
}
// AuthResponse represents the metadata
// returned from an AUTH command to nsqd.
//
// It is decoded from the JSON body of the AUTH response; within this file
// its fields are only written to the log after a successful AUTH.
type AuthResponse struct {
// Identity is the identity reported by nsqd for the supplied secret.
Identity string `json:"identity"`
// IdentityUrl is presumably a URL describing the identity — confirm against nsqd docs.
IdentityUrl string `json:"identity_url"`
// PermissionCount is presumably the number of permissions granted — confirm against nsqd docs.
PermissionCount int64 `json:"permission_count"`
}
// msgResponse carries the outcome of handling a single message from the
// message callbacks (onMessageFinish / onMessageRequeue) to writeLoop.
type msgResponse struct {
// msg is the message being responded to.
msg *Message
// cmd is the command (FIN or REQ) to write to nsqd for msg.
cmd *Command
// success is true for a finish and false for a requeue.
success bool
// backoff selects OnBackoff (true) or OnContinue (false) on the delegate
// when the message is requeued.
backoff bool
}
// Conn represents a connection to nsqd
//
// Conn exposes a set of callbacks for the
// various events that occur on a connection;
// they are delivered through the ConnDelegate passed to NewConn.
type Conn struct {
// 64bit atomic vars need to be first for proper alignment on 32bit platforms
// messagesInFlight counts messages handed to the delegate that have not yet
// been answered through msgResponseChan.
messagesInFlight int64
// maxRdyCount defaults to 2500 and is replaced by the value in the IDENTIFY response.
maxRdyCount int64
// rdyCount is set by SetRDY and decremented for every message received.
rdyCount int64
// lastRdyCount is the last value passed to SetRDY.
lastRdyCount int64
// lastRdyTimestamp is the UnixNano time of the last SetRDY call with rdy > 0.
lastRdyTimestamp int64
// lastMsgTimestamp is the UnixNano time of the last received message
// (initialised to the construction time).
lastMsgTimestamp int64
// mtx serialises command writes and flushes in WriteCommand.
mtx sync.Mutex
config *Config
// conn is the underlying TCP connection, set by Connect.
conn *net.TCPConn
// tlsConn is non-nil once the connection has been upgraded to TLS.
tlsConn *tls.Conn
addr string
delegate ConnDelegate
// logger, logLvl and logFmt are guarded by logGuard.
logger logger
logLvl LogLevel
logFmt string
logGuard sync.RWMutex
// r and w are the current read/write ends of the connection; they start as
// the raw TCP connection and may be replaced by TLS, compression and
// buffering wrappers during IDENTIFY.
r io.Reader
w io.Writer
// cmdChan feeds commands (TOUCH) to writeLoop.
cmdChan chan *Command
// msgResponseChan feeds FIN/REQ responses to writeLoop and, after it has
// exited, to cleanup.
msgResponseChan chan *msgResponse
// exitChan is closed by close() to stop writeLoop.
exitChan chan int
// drainReady is closed by writeLoop once it stops receiving responses.
drainReady chan int
// closeFlag is set to 1 by Close.
closeFlag int32
// stopper guarantees the body of close() runs only once.
stopper sync.Once
// wg tracks readLoop, writeLoop and the cleanup goroutine.
wg sync.WaitGroup
// readLoopRunning is 1 while readLoop is running and 0 once it has exited.
readLoopRunning int32
}
// NewConn builds a Conn for the nsqd at addr. No network activity happens
// until Connect is called.
//
// It panics when config was not produced by NewConfig(). The maximum RDY
// count starts at 2500 and the last-message timestamp starts at the time of
// construction.
func NewConn(addr string, config *Config, delegate ConnDelegate) *Conn {
	if !config.initialized {
		panic("Config must be created with NewConfig()")
	}

	conn := &Conn{
		addr:     addr,
		config:   config,
		delegate: delegate,
	}

	conn.maxRdyCount = 2500
	conn.lastMsgTimestamp = time.Now().UnixNano()

	// all channels are unbuffered
	conn.cmdChan = make(chan *Command)
	conn.msgResponseChan = make(chan *msgResponse)
	conn.exitChan = make(chan int)
	conn.drainReady = make(chan int)

	return conn
}
// SetLogger installs the logger and the minimum level used by this
// connection.
//
// format must be a printf compatible string containing a single %s verb; the
// connection's address is substituted into it and the result is included in
// every log line. An empty format selects the default, '(%s)'.
//
// The logger only needs to implement the following method (the stdlib
// log.Logger qualifies):
//
//	Output(calldepth int, s string)
func (c *Conn) SetLogger(l logger, lvl LogLevel, format string) {
	if format == "" {
		format = "(%s)"
	}

	c.logGuard.Lock()
	c.logger, c.logLvl, c.logFmt = l, lvl, format
	c.logGuard.Unlock()
}
// getLogger returns a consistent snapshot of the logger, level and prefix
// format, taken under the read lock.
func (c *Conn) getLogger() (logger, LogLevel, string) {
	c.logGuard.RLock()
	l, lvl, format := c.logger, c.logLvl, c.logFmt
	c.logGuard.RUnlock()
	return l, lvl, format
}
// Connect dials and bootstraps the nsqd connection
// (including IDENTIFY and, when nsqd requires it, AUTH) and returns the
// IdentifyResponse.
//
// On success the read and write loops are started. On any failure after the
// dial has succeeded the underlying TCP connection is closed so that the
// socket is not left open; the Conn must not be reused afterwards.
//
// The returned IdentifyResponse is nil (with a nil error) when nsqd answered
// IDENTIFY with a non-JSON response.
func (c *Conn) Connect() (*IdentifyResponse, error) {
	dialer := &net.Dialer{
		LocalAddr: c.config.LocalAddr,
		Timeout:   c.config.DialTimeout,
	}

	conn, err := dialer.Dial("tcp", c.addr)
	if err != nil {
		return nil, err
	}
	c.conn = conn.(*net.TCPConn)
	c.r = conn
	c.w = conn

	if _, err = c.Write(MagicV2); err != nil {
		c.Close()
		// Close() only shuts down the read side; release the socket too.
		// The close error is ignored because the write error is what matters.
		c.conn.Close()
		return nil, fmt.Errorf("[%s] failed to write magic - %s", c.addr, err)
	}

	resp, err := c.identify()
	if err != nil {
		c.conn.Close() // best effort, the identify error is returned
		return nil, err
	}

	if resp != nil && resp.AuthRequired {
		if c.config.AuthSecret == "" {
			c.log(LogLevelError, "Auth Required")
			c.conn.Close() // best effort
			return nil, errors.New("Auth Required")
		}
		if err := c.auth(c.config.AuthSecret); err != nil {
			c.log(LogLevelError, "Auth Failed %s", err)
			c.conn.Close() // best effort, the auth error is returned
			return nil, err
		}
	}

	// register both loops with the wait group before starting them so that
	// waitForCleanup cannot observe a zero counter in between
	c.wg.Add(2)
	atomic.StoreInt32(&c.readLoopRunning, 1)
	go c.readLoop()
	go c.writeLoop()
	return resp, nil
}
// Close idempotently initiates connection close.
//
// It always raises the close flag. The read side of the TCP connection is
// shut down right away only when a connection exists and no messages are in
// flight; otherwise nil is returned and the shutdown is left to the loops.
func (c *Conn) Close() error {
	atomic.StoreInt32(&c.closeFlag, 1)

	if c.conn == nil {
		return nil
	}
	if atomic.LoadInt64(&c.messagesInFlight) != 0 {
		return nil
	}
	return c.conn.CloseRead()
}
// IsClosing indicates whether or not the
// connection is currently in the process of
// gracefully closing, i.e. whether Close has been called.
func (c *Conn) IsClosing() bool {
	flag := atomic.LoadInt32(&c.closeFlag)
	return flag == 1
}
// RDY returns the current RDY count: the value last stored with SetRDY,
// reduced by one for each message received since.
func (c *Conn) RDY() int64 {
	current := atomic.LoadInt64(&c.rdyCount)
	return current
}
// LastRDY returns the RDY count most recently passed to SetRDY.
func (c *Conn) LastRDY() int64 {
	last := atomic.LoadInt64(&c.lastRdyCount)
	return last
}
// SetRDY stores the specified RDY count as both the current and the last RDY
// count. The last-RDY timestamp is refreshed only for counts greater than
// zero.
func (c *Conn) SetRDY(rdy int64) {
	atomic.StoreInt64(&c.rdyCount, rdy)
	atomic.StoreInt64(&c.lastRdyCount, rdy)

	if rdy <= 0 {
		return
	}
	atomic.StoreInt64(&c.lastRdyTimestamp, time.Now().UnixNano())
}
// MaxRDY returns the nsqd negotiated maximum
// RDY count that it will accept for this connection.
//
// The value is 2500 until an IDENTIFY response supplies another one. The
// field lives in the struct's group of 64bit atomic variables, so it is read
// with an atomic load like its neighbours.
func (c *Conn) MaxRDY() int64 {
	return atomic.LoadInt64(&c.maxRdyCount)
}
// LastRdyTime returns a time.Time representing the moment SetRDY was last
// called with a count greater than zero. If that never happened the result
// is the Unix epoch.
func (c *Conn) LastRdyTime() time.Time {
	nanos := atomic.LoadInt64(&c.lastRdyTimestamp)
	return time.Unix(0, nanos)
}
// LastMessageTime returns a time.Time representing
// the time at which the last message was received. Before any message has
// arrived it reports the time the Conn was constructed.
func (c *Conn) LastMessageTime() time.Time {
	nanos := atomic.LoadInt64(&c.lastMsgTimestamp)
	return time.Unix(0, nanos)
}
// RemoteAddr returns the remote network address of the underlying TCP
// connection to nsqd.
func (c *Conn) RemoteAddr() net.Addr {
	remote := c.conn.RemoteAddr()
	return remote
}
// String returns the address this Conn was created with, exactly as it was
// passed to NewConn.
func (c *Conn) String() string {
	addr := c.addr
	return addr
}
// Read performs a deadlined read: the TCP connection's read deadline is
// pushed out by the configured ReadTimeout and then p is filled from the
// connection's current reader.
func (c *Conn) Read(p []byte) (int, error) {
	deadline := time.Now().Add(c.config.ReadTimeout)
	c.conn.SetReadDeadline(deadline)
	return c.r.Read(p)
}
// Write performs a deadlined write: the TCP connection's write deadline is
// pushed out by the configured WriteTimeout and then p is handed to the
// connection's current writer.
func (c *Conn) Write(p []byte) (int, error) {
	deadline := time.Now().Add(c.config.WriteTimeout)
	c.conn.SetWriteDeadline(deadline)
	return c.w.Write(p)
}
// WriteCommand is a goroutine safe method to write a Command
// to this connection, and flush.
//
// The write and the flush happen under the connection mutex. A failure of
// either is logged and reported to the delegate through OnIOError after the
// mutex has been released, and is then returned to the caller.
func (c *Conn) WriteCommand(cmd *Command) error {
	c.mtx.Lock()
	_, err := cmd.WriteTo(c)
	if err == nil {
		// only flush what was written successfully
		err = c.Flush()
	}
	c.mtx.Unlock()

	if err == nil {
		return nil
	}

	c.log(LogLevelError, "IO error - %s", err)
	c.delegate.OnIOError(c, err)
	return err
}
// flusher is implemented by writers that buffer output and can be asked to
// push it to the underlying connection; Conn.Flush checks for it.
type flusher interface {
Flush() error
}
// Flush writes all buffered data to the underlying TCP connection.
//
// When the current writer has no Flush method there is nothing buffered and
// nil is returned.
func (c *Conn) Flush() error {
	f, ok := c.w.(flusher)
	if !ok {
		return nil
	}
	return f.Flush()
}
// identify sends the IDENTIFY command built from the connection's Config and
// processes nsqd's reply.
//
// When the reply is a JSON document it is decoded into an IdentifyResponse,
// the negotiated maximum RDY count is recorded, and the connection is
// upgraded (TLS, then Deflate, then Snappy) as the response dictates. Read
// buffering is then enabled, along with write buffering if the writer cannot
// already Flush.
//
// It returns (nil, nil) when the reply is empty or is not JSON, meaning the
// server did not report capabilities; no upgrade or buffering is set up in
// that case. Every failure is returned as an ErrIdentify.
func (c *Conn) identify() (*IdentifyResponse, error) {
	ci := make(map[string]interface{})
	ci["client_id"] = c.config.ClientID
	ci["hostname"] = c.config.Hostname
	ci["user_agent"] = c.config.UserAgent
	ci["short_id"] = c.config.ClientID // deprecated
	ci["long_id"] = c.config.Hostname  // deprecated
	ci["tls_v1"] = c.config.TlsV1
	ci["deflate"] = c.config.Deflate
	ci["deflate_level"] = c.config.DeflateLevel
	ci["snappy"] = c.config.Snappy
	ci["feature_negotiation"] = true
	// -1 is passed through untouched; any other value is sent in milliseconds
	if c.config.HeartbeatInterval == -1 {
		ci["heartbeat_interval"] = -1
	} else {
		ci["heartbeat_interval"] = int64(c.config.HeartbeatInterval / time.Millisecond)
	}
	ci["sample_rate"] = c.config.SampleRate
	ci["output_buffer_size"] = c.config.OutputBufferSize
	if c.config.OutputBufferTimeout == -1 {
		ci["output_buffer_timeout"] = -1
	} else {
		ci["output_buffer_timeout"] = int64(c.config.OutputBufferTimeout / time.Millisecond)
	}
	ci["msg_timeout"] = int64(c.config.MsgTimeout / time.Millisecond)

	cmd, err := Identify(ci)
	if err != nil {
		return nil, ErrIdentify{err.Error()}
	}

	err = c.WriteCommand(cmd)
	if err != nil {
		return nil, ErrIdentify{err.Error()}
	}

	frameType, data, err := ReadUnpackedResponse(c)
	if err != nil {
		return nil, ErrIdentify{err.Error()}
	}

	if frameType == FrameTypeError {
		return nil, ErrIdentify{string(data)}
	}

	// check to see if the server was able to respond w/ capabilities
	// i.e. it was a JSON response; an empty body is treated the same way
	// instead of indexing into it and panicking
	if len(data) == 0 || data[0] != '{' {
		return nil, nil
	}

	resp := &IdentifyResponse{}
	err = json.Unmarshal(data, resp)
	if err != nil {
		return nil, ErrIdentify{err.Error()}
	}

	c.log(LogLevelDebug, "IDENTIFY response: %+v", resp)

	// maxRdyCount belongs to the struct's group of atomic variables
	atomic.StoreInt64(&c.maxRdyCount, resp.MaxRdyCount)

	if resp.TLSv1 {
		c.log(LogLevelInfo, "upgrading to TLS")
		err := c.upgradeTLS(c.config.TlsConfig)
		if err != nil {
			return nil, ErrIdentify{err.Error()}
		}
	}

	if resp.Deflate {
		c.log(LogLevelInfo, "upgrading to Deflate")
		err := c.upgradeDeflate(c.config.DeflateLevel)
		if err != nil {
			return nil, ErrIdentify{err.Error()}
		}
	}

	if resp.Snappy {
		c.log(LogLevelInfo, "upgrading to Snappy")
		err := c.upgradeSnappy()
		if err != nil {
			return nil, ErrIdentify{err.Error()}
		}
	}

	// now that connection is bootstrapped, enable read buffering
	// (and write buffering if it's not already capable of Flush())
	c.r = bufio.NewReader(c.r)
	if _, ok := c.w.(flusher); !ok {
		c.w = bufio.NewWriter(c.w)
	}

	return resp, nil
}
// upgradeTLS wraps the TCP connection in a TLS client connection, performs
// the handshake and then expects an "OK" response frame from nsqd over the
// encrypted channel.
//
// tlsConf may be nil, in which case an empty configuration is used. The
// caller's configuration is never modified: a clone is made and its
// ServerName is set to the host part of the connection address.
func (c *Conn) upgradeTLS(tlsConf *tls.Config) error {
	host, _, err := net.SplitHostPort(c.addr)
	if err != nil {
		return err
	}

	// Clone instead of copying by value: tls.Config contains locks that must
	// not be copied.
	conf := &tls.Config{}
	if tlsConf != nil {
		conf = tlsConf.Clone()
	}
	conf.ServerName = host

	c.tlsConn = tls.Client(c.conn, conf)
	err = c.tlsConn.Handshake()
	if err != nil {
		return err
	}
	c.r = c.tlsConn
	c.w = c.tlsConn

	frameType, data, err := ReadUnpackedResponse(c)
	if err != nil {
		return err
	}
	if frameType != FrameTypeResponse || !bytes.Equal(data, []byte("OK")) {
		return errors.New("invalid response from TLS upgrade")
	}
	return nil
}
// upgradeDeflate layers deflate compression at the given level on top of the
// current transport (the TLS connection when one exists, the TCP connection
// otherwise) and then expects an "OK" response frame from nsqd.
//
// An invalid level is reported as an error before the connection's reader
// and writer are touched, instead of installing a nil writer.
func (c *Conn) upgradeDeflate(level int) error {
	conn := net.Conn(c.conn)
	if c.tlsConn != nil {
		conn = c.tlsConn
	}

	fw, err := flate.NewWriter(conn, level)
	if err != nil {
		return err
	}
	c.r = flate.NewReader(conn)
	c.w = fw

	frameType, data, err := ReadUnpackedResponse(c)
	if err != nil {
		return err
	}
	if frameType != FrameTypeResponse || !bytes.Equal(data, []byte("OK")) {
		return errors.New("invalid response from Deflate upgrade")
	}
	return nil
}
// upgradeSnappy layers snappy compression on top of the current transport
// (the TLS connection when one exists, the TCP connection otherwise) and then
// expects an "OK" response frame from nsqd.
func (c *Conn) upgradeSnappy() error {
	var transport net.Conn = c.conn
	if c.tlsConn != nil {
		transport = c.tlsConn
	}

	c.r = snappy.NewReader(transport)
	c.w = snappy.NewWriter(transport)

	frameType, data, err := ReadUnpackedResponse(c)
	switch {
	case err != nil:
		return err
	case frameType == FrameTypeResponse && bytes.Equal(data, []byte("OK")):
		return nil
	}
	return errors.New("invalid response from Snappy upgrade")
}
// auth sends an AUTH command carrying secret and waits for nsqd's reply.
//
// An error frame is turned into an error containing the frame body. Any
// other reply is decoded as an AuthResponse, whose contents are logged.
func (c *Conn) auth(secret string) error {
	cmd, err := Auth(secret)
	if err != nil {
		return err
	}

	if err = c.WriteCommand(cmd); err != nil {
		return err
	}

	frameType, data, err := ReadUnpackedResponse(c)
	if err != nil {
		return err
	}
	if frameType == FrameTypeError {
		return errors.New("Error authenticating " + string(data))
	}

	var resp AuthResponse
	if err = json.Unmarshal(data, &resp); err != nil {
		return err
	}

	c.log(LogLevelInfo, "Auth accepted. Identity: %q %s Permissions: %d",
		resp.Identity, resp.IdentityUrl, resp.PermissionCount)
	return nil
}
// readLoop reads frames from nsqd until the connection is closing or an error
// occurs, dispatching each frame to the delegate:
//
//   - "_heartbeat_" responses are answered with a NOP
//   - other responses go to OnResponse
//   - messages are decoded, accounted for (RDY count, in-flight count, last
//     message time) and handed to OnMessage
//   - error frames go to OnError
//   - unknown frame types are logged and reported through OnIOError, and the
//     loop keeps running
//
// On exit it marks the read loop as stopped and, when no messages are in
// flight, starts the connection close itself.
func (c *Conn) readLoop() {
	delegate := &connMessageDelegate{c}
	for {
		if atomic.LoadInt32(&c.closeFlag) == 1 {
			goto exit
		}

		frameType, data, err := ReadUnpackedResponse(c)
		if err != nil {
			// EOF after Close() is the expected way out
			if err == io.EOF && atomic.LoadInt32(&c.closeFlag) == 1 {
				goto exit
			}
			if !strings.Contains(err.Error(), "use of closed network connection") {
				c.log(LogLevelError, "IO error - %s", err)
				c.delegate.OnIOError(c, err)
			}
			goto exit
		}

		if frameType == FrameTypeResponse && bytes.Equal(data, []byte("_heartbeat_")) {
			c.log(LogLevelDebug, "heartbeat received")
			c.delegate.OnHeartbeat(c)
			err := c.WriteCommand(Nop())
			if err != nil {
				// NOTE(review): WriteCommand already logs and reports this
				// error, so the delegate is notified twice.
				c.log(LogLevelError, "IO error - %s", err)
				c.delegate.OnIOError(c, err)
				goto exit
			}
			continue
		}

		switch frameType {
		case FrameTypeResponse:
			c.delegate.OnResponse(c, data)
		case FrameTypeMessage:
			msg, err := DecodeMessage(data)
			if err != nil {
				c.log(LogLevelError, "IO error - %s", err)
				c.delegate.OnIOError(c, err)
				goto exit
			}
			msg.Delegate = delegate
			msg.NSQDAddress = c.String()

			atomic.AddInt64(&c.rdyCount, -1)
			atomic.AddInt64(&c.messagesInFlight, 1)
			atomic.StoreInt64(&c.lastMsgTimestamp, time.Now().UnixNano())

			c.delegate.OnMessage(c, msg)
		case FrameTypeError:
			c.log(LogLevelError, "protocol error - %s", data)
			c.delegate.OnError(c, data)
		default:
			// the read itself succeeded, so build the error that describes
			// the problem and use it for both the log line and the delegate
			unknownErr := fmt.Errorf("unknown frame type %d", frameType)
			c.log(LogLevelError, "IO error - %s", unknownErr)
			c.delegate.OnIOError(c, unknownErr)
		}
	}

exit:
	atomic.StoreInt32(&c.readLoopRunning, 0)
	// start the connection close
	messagesInFlight := atomic.LoadInt64(&c.messagesInFlight)
	if messagesInFlight == 0 {
		// if we exited readLoop with no messages in flight
		// we need to explicitly trigger the close because
		// writeLoop won't
		c.close()
	} else {
		c.log(LogLevelWarning, "delaying close, %d outstanding messages", messagesInFlight)
	}
	c.wg.Done()
	c.log(LogLevelInfo, "readLoop exiting")
}
// writeLoop serialises everything that is written to nsqd after bootstrap.
//
// It forwards commands from cmdChan, turns message responses from
// msgResponseChan into FIN/REQ commands (notifying the delegate first), and
// stops once exitChan is closed. A failed write triggers close(); so does the
// last in-flight message being answered while the connection is closing.
func (c *Conn) writeLoop() {
loop:
	for {
		select {
		case <-c.exitChan:
			c.log(LogLevelInfo, "breaking out of writeLoop")
			// nothing further will be received from msgResponseChan here,
			// so let the cleanup goroutine take over draining it
			close(c.drainReady)
			break loop

		case cmd := <-c.cmdChan:
			if err := c.WriteCommand(cmd); err != nil {
				c.log(LogLevelError, "error sending command %s - %s", cmd, err)
				c.close()
			}

		case resp := <-c.msgResponseChan:
			// account for the message up front so the counter is right even
			// when the response cannot be delivered to nsqd
			remaining := atomic.AddInt64(&c.messagesInFlight, -1)

			if !resp.success {
				c.log(LogLevelDebug, "REQ %s", resp.msg.ID)
				c.delegate.OnMessageRequeued(c, resp.msg)
				if resp.backoff {
					c.delegate.OnBackoff(c)
				} else {
					c.delegate.OnContinue(c)
				}
			} else {
				c.log(LogLevelDebug, "FIN %s", resp.msg.ID)
				c.delegate.OnMessageFinished(c, resp.msg)
				c.delegate.OnResume(c)
			}

			if err := c.WriteCommand(resp.cmd); err != nil {
				c.log(LogLevelError, "error sending command %s - %s", resp.cmd, err)
				c.close()
			} else if remaining == 0 && atomic.LoadInt32(&c.closeFlag) == 1 {
				// that was the last outstanding message of a closing connection
				c.close()
			}
		}
	}

	c.wg.Done()
	c.log(LogLevelInfo, "writeLoop exiting")
}
// close starts the teardown of the connection. It may be called any number
// of times from any goroutine; only the first call has an effect.
//
// A "clean" connection close is orchestrated as follows:
//
//  1. CLOSE cmd sent to nsqd
//  2. CLOSE_WAIT response received from nsqd
//  3. c.closeFlag is set
//  4. readLoop() exits
//     - with messages in flight, close() is delayed: writeLoop() keeps
//     receiving on c.msgResponseChan and calls close() once the count
//     reaches zero
//     - otherwise readLoop() calls close() immediately
//  5. c.exitChan is closed
//     - writeLoop() exits and closes c.drainReady
//  6. the cleanup() goroutine is launched (we're racing with intraprocess
//     routed messages, see the comments in cleanup)
//     - waits on c.drainReady
//     - receives on c.msgResponseChan until messages-in-flight == 0
//     - ensures that readLoop has exited
//  7. the waitForCleanup() goroutine is launched
//     - waits on the waitgroup (covers readLoop(), writeLoop() and the
//     cleanup goroutine)
//     - closes the underlying TCP connection for writing
//     - triggers Delegate OnClose()
func (c *Conn) close() {
	shutdown := func() {
		c.log(LogLevelInfo, "beginning close")
		close(c.exitChan)
		c.conn.CloseRead()

		// account for the cleanup goroutine before waitForCleanup can wait
		c.wg.Add(1)
		go c.cleanup()

		go c.waitForCleanup()
	}
	c.stopper.Do(shutdown)
}
// cleanup takes over from writeLoop once it has exited: it keeps receiving
// (and discarding) message responses until no messages are in flight and
// readLoop has stopped. While waiting it logs a warning at most once per
// second.
func (c *Conn) cleanup() {
	// wait until writeLoop no longer receives from msgResponseChan
	<-c.drainReady

	ticker := time.NewTicker(100 * time.Millisecond)
	lastWarning := time.Now()

	for {
		// readLoop may still be holding a message that is about to be
		// handled, so keep going until the in-flight count is zero AND
		// readLoop has exited
		var remaining int64
		select {
		case <-c.msgResponseChan:
			remaining = atomic.AddInt64(&c.messagesInFlight, -1)
		case <-ticker.C:
			remaining = atomic.LoadInt64(&c.messagesInFlight)
		}

		// only once nothing is in flight does readLoop's state decide
		// whether another message could still show up
		if remaining <= 0 && atomic.LoadInt32(&c.readLoopRunning) != 1 {
			break
		}

		if time.Now().Sub(lastWarning) > time.Second {
			if remaining > 0 {
				c.log(LogLevelWarning, "draining... waiting for %d messages in flight", remaining)
			} else {
				c.log(LogLevelWarning, "draining... readLoop still running")
			}
			lastWarning = time.Now()
		}
	}

	ticker.Stop()
	c.wg.Done()
	c.log(LogLevelInfo, "finished draining, cleanup exiting")
}
// waitForCleanup is the last step of a connection close: after readLoop,
// writeLoop and the cleanup goroutine have all finished it shuts down the
// write side of the TCP connection and notifies the delegate via OnClose.
func (c *Conn) waitForCleanup() {
	// every goroutine registered on the wait group must be done first
	c.wg.Wait()

	c.conn.CloseWrite()
	c.log(LogLevelInfo, "clean close complete")

	c.delegate.OnClose(c)
}
// onMessageFinish queues a successful response (FIN) for m. It blocks until
// the response has been received from msgResponseChan.
func (c *Conn) onMessageFinish(m *Message) {
	resp := &msgResponse{
		msg:     m,
		cmd:     Finish(m.ID),
		success: true,
	}
	c.msgResponseChan <- resp
}
// onMessageRequeue queues a requeue response (REQ) for m. It blocks until
// the response has been received from msgResponseChan.
//
// A delay of -1 selects a linear delay of DefaultRequeueDelay multiplied by
// the message's attempt count, capped at MaxRequeueDelay. backoff is carried
// along in the response.
func (c *Conn) onMessageRequeue(m *Message, delay time.Duration, backoff bool) {
	if delay == -1 {
		// grow linearly with the number of attempts...
		delay = c.config.DefaultRequeueDelay * time.Duration(m.Attempts)
		// ...but never beyond the configured ceiling
		if ceiling := c.config.MaxRequeueDelay; delay > ceiling {
			delay = ceiling
		}
	}

	resp := &msgResponse{
		msg:     m,
		cmd:     Requeue(m.ID, delay),
		success: false,
		backoff: backoff,
	}
	c.msgResponseChan <- resp
}
// onMessageTouch hands a TOUCH command for m to writeLoop. The command is
// dropped if the connection's exit channel is closed before it is accepted.
func (c *Conn) onMessageTouch(m *Message) {
	touch := Touch(m.ID)
	select {
	case c.cmdChan <- touch:
	case <-c.exitChan:
	}
}
// log emits one formatted line through the configured logger. Nothing is
// written when no logger is set or when lvl is below the configured level.
//
// Each line consists of the level, the connection address rendered through
// the configured prefix format, and the formatted message.
func (c *Conn) log(lvl LogLevel, line string, args ...interface{}) {
	out, minLvl, prefixFmt := c.getLogger()
	if out == nil || minLvl > lvl {
		return
	}

	prefix := fmt.Sprintf(prefixFmt, c.String())
	message := fmt.Sprintf(line, args...)
	out.Output(2, fmt.Sprintf("%-4s %s %s", lvl, prefix, message))
}