package nsq import ( "bufio" "bytes" "compress/flate" "crypto/tls" "encoding/json" "errors" "fmt" "io" "net" "strings" "sync" "sync/atomic" "time" "github.com/golang/snappy" ) // IdentifyResponse represents the metadata // returned from an IDENTIFY command to nsqd type IdentifyResponse struct { MaxRdyCount int64 `json:"max_rdy_count"` TLSv1 bool `json:"tls_v1"` Deflate bool `json:"deflate"` Snappy bool `json:"snappy"` AuthRequired bool `json:"auth_required"` } // AuthResponse represents the metadata // returned from an AUTH command to nsqd type AuthResponse struct { Identity string `json:"identity"` IdentityUrl string `json:"identity_url"` PermissionCount int64 `json:"permission_count"` } type msgResponse struct { msg *Message cmd *Command success bool backoff bool } // Conn represents a connection to nsqd // // Conn exposes a set of callbacks for the // various events that occur on a connection type Conn struct { // 64bit atomic vars need to be first for proper alignment on 32bit platforms messagesInFlight int64 maxRdyCount int64 rdyCount int64 lastRdyCount int64 lastRdyTimestamp int64 lastMsgTimestamp int64 mtx sync.Mutex config *Config conn *net.TCPConn tlsConn *tls.Conn addr string delegate ConnDelegate logger logger logLvl LogLevel logFmt string logGuard sync.RWMutex r io.Reader w io.Writer cmdChan chan *Command msgResponseChan chan *msgResponse exitChan chan int drainReady chan int closeFlag int32 stopper sync.Once wg sync.WaitGroup readLoopRunning int32 } // NewConn returns a new Conn instance func NewConn(addr string, config *Config, delegate ConnDelegate) *Conn { if !config.initialized { panic("Config must be created with NewConfig()") } return &Conn{ addr: addr, config: config, delegate: delegate, maxRdyCount: 2500, lastMsgTimestamp: time.Now().UnixNano(), cmdChan: make(chan *Command), msgResponseChan: make(chan *msgResponse), exitChan: make(chan int), drainReady: make(chan int), } } // SetLogger assigns the logger to use as well as a level. // // The format parameter is expected to be a printf compatible string with // a single %s argument. This is useful if you want to provide additional // context to the log messages that the connection will print, the default // is '(%s)'. // // The logger parameter is an interface that requires the following // method to be implemented (such as the the stdlib log.Logger): // // Output(calldepth int, s string) // func (c *Conn) SetLogger(l logger, lvl LogLevel, format string) { c.logGuard.Lock() defer c.logGuard.Unlock() c.logger = l c.logLvl = lvl c.logFmt = format if c.logFmt == "" { c.logFmt = "(%s)" } } func (c *Conn) getLogger() (logger, LogLevel, string) { c.logGuard.RLock() defer c.logGuard.RUnlock() return c.logger, c.logLvl, c.logFmt } // Connect dials and bootstraps the nsqd connection // (including IDENTIFY) and returns the IdentifyResponse func (c *Conn) Connect() (*IdentifyResponse, error) { dialer := &net.Dialer{ LocalAddr: c.config.LocalAddr, Timeout: c.config.DialTimeout, } conn, err := dialer.Dial("tcp", c.addr) if err != nil { return nil, err } c.conn = conn.(*net.TCPConn) c.r = conn c.w = conn _, err = c.Write(MagicV2) if err != nil { c.Close() return nil, fmt.Errorf("[%s] failed to write magic - %s", c.addr, err) } resp, err := c.identify() if err != nil { return nil, err } if resp != nil && resp.AuthRequired { if c.config.AuthSecret == "" { c.log(LogLevelError, "Auth Required") return nil, errors.New("Auth Required") } err := c.auth(c.config.AuthSecret) if err != nil { c.log(LogLevelError, "Auth Failed %s", err) return nil, err } } c.wg.Add(2) atomic.StoreInt32(&c.readLoopRunning, 1) go c.readLoop() go c.writeLoop() return resp, nil } // Close idempotently initiates connection close func (c *Conn) Close() error { atomic.StoreInt32(&c.closeFlag, 1) if c.conn != nil && atomic.LoadInt64(&c.messagesInFlight) == 0 { return c.conn.CloseRead() } return nil } // IsClosing indicates whether or not the // connection is currently in the processing of // gracefully closing func (c *Conn) IsClosing() bool { return atomic.LoadInt32(&c.closeFlag) == 1 } // RDY returns the current RDY count func (c *Conn) RDY() int64 { return atomic.LoadInt64(&c.rdyCount) } // LastRDY returns the previously set RDY count func (c *Conn) LastRDY() int64 { return atomic.LoadInt64(&c.lastRdyCount) } // SetRDY stores the specified RDY count func (c *Conn) SetRDY(rdy int64) { atomic.StoreInt64(&c.rdyCount, rdy) atomic.StoreInt64(&c.lastRdyCount, rdy) if rdy > 0 { atomic.StoreInt64(&c.lastRdyTimestamp, time.Now().UnixNano()) } } // MaxRDY returns the nsqd negotiated maximum // RDY count that it will accept for this connection func (c *Conn) MaxRDY() int64 { return c.maxRdyCount } func (c *Conn) LastRdyTime() time.Time { return time.Unix(0, atomic.LoadInt64(&c.lastRdyTimestamp)) } // LastMessageTime returns a time.Time representing // the time at which the last message was received func (c *Conn) LastMessageTime() time.Time { return time.Unix(0, atomic.LoadInt64(&c.lastMsgTimestamp)) } // RemoteAddr returns the configured destination nsqd address func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } // String returns the fully-qualified address func (c *Conn) String() string { return c.addr } // Read performs a deadlined read on the underlying TCP connection func (c *Conn) Read(p []byte) (int, error) { c.conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)) return c.r.Read(p) } // Write performs a deadlined write on the underlying TCP connection func (c *Conn) Write(p []byte) (int, error) { c.conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)) return c.w.Write(p) } // WriteCommand is a goroutine safe method to write a Command // to this connection, and flush. func (c *Conn) WriteCommand(cmd *Command) error { c.mtx.Lock() _, err := cmd.WriteTo(c) if err != nil { goto exit } err = c.Flush() exit: c.mtx.Unlock() if err != nil { c.log(LogLevelError, "IO error - %s", err) c.delegate.OnIOError(c, err) } return err } type flusher interface { Flush() error } // Flush writes all buffered data to the underlying TCP connection func (c *Conn) Flush() error { if f, ok := c.w.(flusher); ok { return f.Flush() } return nil } func (c *Conn) identify() (*IdentifyResponse, error) { ci := make(map[string]interface{}) ci["client_id"] = c.config.ClientID ci["hostname"] = c.config.Hostname ci["user_agent"] = c.config.UserAgent ci["short_id"] = c.config.ClientID // deprecated ci["long_id"] = c.config.Hostname // deprecated ci["tls_v1"] = c.config.TlsV1 ci["deflate"] = c.config.Deflate ci["deflate_level"] = c.config.DeflateLevel ci["snappy"] = c.config.Snappy ci["feature_negotiation"] = true if c.config.HeartbeatInterval == -1 { ci["heartbeat_interval"] = -1 } else { ci["heartbeat_interval"] = int64(c.config.HeartbeatInterval / time.Millisecond) } ci["sample_rate"] = c.config.SampleRate ci["output_buffer_size"] = c.config.OutputBufferSize if c.config.OutputBufferTimeout == -1 { ci["output_buffer_timeout"] = -1 } else { ci["output_buffer_timeout"] = int64(c.config.OutputBufferTimeout / time.Millisecond) } ci["msg_timeout"] = int64(c.config.MsgTimeout / time.Millisecond) cmd, err := Identify(ci) if err != nil { return nil, ErrIdentify{err.Error()} } err = c.WriteCommand(cmd) if err != nil { return nil, ErrIdentify{err.Error()} } frameType, data, err := ReadUnpackedResponse(c) if err != nil { return nil, ErrIdentify{err.Error()} } if frameType == FrameTypeError { return nil, ErrIdentify{string(data)} } // check to see if the server was able to respond w/ capabilities // i.e. it was a JSON response if data[0] != '{' { return nil, nil } resp := &IdentifyResponse{} err = json.Unmarshal(data, resp) if err != nil { return nil, ErrIdentify{err.Error()} } c.log(LogLevelDebug, "IDENTIFY response: %+v", resp) c.maxRdyCount = resp.MaxRdyCount if resp.TLSv1 { c.log(LogLevelInfo, "upgrading to TLS") err := c.upgradeTLS(c.config.TlsConfig) if err != nil { return nil, ErrIdentify{err.Error()} } } if resp.Deflate { c.log(LogLevelInfo, "upgrading to Deflate") err := c.upgradeDeflate(c.config.DeflateLevel) if err != nil { return nil, ErrIdentify{err.Error()} } } if resp.Snappy { c.log(LogLevelInfo, "upgrading to Snappy") err := c.upgradeSnappy() if err != nil { return nil, ErrIdentify{err.Error()} } } // now that connection is bootstrapped, enable read buffering // (and write buffering if it's not already capable of Flush()) c.r = bufio.NewReader(c.r) if _, ok := c.w.(flusher); !ok { c.w = bufio.NewWriter(c.w) } return resp, nil } func (c *Conn) upgradeTLS(tlsConf *tls.Config) error { // create a local copy of the config to set ServerName for this connection var conf tls.Config if tlsConf != nil { conf = *tlsConf } host, _, err := net.SplitHostPort(c.addr) if err != nil { return err } conf.ServerName = host c.tlsConn = tls.Client(c.conn, &conf) err = c.tlsConn.Handshake() if err != nil { return err } c.r = c.tlsConn c.w = c.tlsConn frameType, data, err := ReadUnpackedResponse(c) if err != nil { return err } if frameType != FrameTypeResponse || !bytes.Equal(data, []byte("OK")) { return errors.New("invalid response from TLS upgrade") } return nil } func (c *Conn) upgradeDeflate(level int) error { conn := net.Conn(c.conn) if c.tlsConn != nil { conn = c.tlsConn } fw, _ := flate.NewWriter(conn, level) c.r = flate.NewReader(conn) c.w = fw frameType, data, err := ReadUnpackedResponse(c) if err != nil { return err } if frameType != FrameTypeResponse || !bytes.Equal(data, []byte("OK")) { return errors.New("invalid response from Deflate upgrade") } return nil } func (c *Conn) upgradeSnappy() error { conn := net.Conn(c.conn) if c.tlsConn != nil { conn = c.tlsConn } c.r = snappy.NewReader(conn) c.w = snappy.NewWriter(conn) frameType, data, err := ReadUnpackedResponse(c) if err != nil { return err } if frameType != FrameTypeResponse || !bytes.Equal(data, []byte("OK")) { return errors.New("invalid response from Snappy upgrade") } return nil } func (c *Conn) auth(secret string) error { cmd, err := Auth(secret) if err != nil { return err } err = c.WriteCommand(cmd) if err != nil { return err } frameType, data, err := ReadUnpackedResponse(c) if err != nil { return err } if frameType == FrameTypeError { return errors.New("Error authenticating " + string(data)) } resp := &AuthResponse{} err = json.Unmarshal(data, resp) if err != nil { return err } c.log(LogLevelInfo, "Auth accepted. Identity: %q %s Permissions: %d", resp.Identity, resp.IdentityUrl, resp.PermissionCount) return nil } func (c *Conn) readLoop() { delegate := &connMessageDelegate{c} for { if atomic.LoadInt32(&c.closeFlag) == 1 { goto exit } frameType, data, err := ReadUnpackedResponse(c) if err != nil { if err == io.EOF && atomic.LoadInt32(&c.closeFlag) == 1 { goto exit } if !strings.Contains(err.Error(), "use of closed network connection") { c.log(LogLevelError, "IO error - %s", err) c.delegate.OnIOError(c, err) } goto exit } if frameType == FrameTypeResponse && bytes.Equal(data, []byte("_heartbeat_")) { c.log(LogLevelDebug, "heartbeat received") c.delegate.OnHeartbeat(c) err := c.WriteCommand(Nop()) if err != nil { c.log(LogLevelError, "IO error - %s", err) c.delegate.OnIOError(c, err) goto exit } continue } switch frameType { case FrameTypeResponse: c.delegate.OnResponse(c, data) case FrameTypeMessage: msg, err := DecodeMessage(data) if err != nil { c.log(LogLevelError, "IO error - %s", err) c.delegate.OnIOError(c, err) goto exit } msg.Delegate = delegate msg.NSQDAddress = c.String() atomic.AddInt64(&c.rdyCount, -1) atomic.AddInt64(&c.messagesInFlight, 1) atomic.StoreInt64(&c.lastMsgTimestamp, time.Now().UnixNano()) c.delegate.OnMessage(c, msg) case FrameTypeError: c.log(LogLevelError, "protocol error - %s", data) c.delegate.OnError(c, data) default: c.log(LogLevelError, "IO error - %s", err) c.delegate.OnIOError(c, fmt.Errorf("unknown frame type %d", frameType)) } } exit: atomic.StoreInt32(&c.readLoopRunning, 0) // start the connection close messagesInFlight := atomic.LoadInt64(&c.messagesInFlight) if messagesInFlight == 0 { // if we exited readLoop with no messages in flight // we need to explicitly trigger the close because // writeLoop won't c.close() } else { c.log(LogLevelWarning, "delaying close, %d outstanding messages", messagesInFlight) } c.wg.Done() c.log(LogLevelInfo, "readLoop exiting") } func (c *Conn) writeLoop() { for { select { case <-c.exitChan: c.log(LogLevelInfo, "breaking out of writeLoop") // Indicate drainReady because we will not pull any more off msgResponseChan close(c.drainReady) goto exit case cmd := <-c.cmdChan: err := c.WriteCommand(cmd) if err != nil { c.log(LogLevelError, "error sending command %s - %s", cmd, err) c.close() continue } case resp := <-c.msgResponseChan: // Decrement this here so it is correct even if we can't respond to nsqd msgsInFlight := atomic.AddInt64(&c.messagesInFlight, -1) if resp.success { c.log(LogLevelDebug, "FIN %s", resp.msg.ID) c.delegate.OnMessageFinished(c, resp.msg) c.delegate.OnResume(c) } else { c.log(LogLevelDebug, "REQ %s", resp.msg.ID) c.delegate.OnMessageRequeued(c, resp.msg) if resp.backoff { c.delegate.OnBackoff(c) } else { c.delegate.OnContinue(c) } } err := c.WriteCommand(resp.cmd) if err != nil { c.log(LogLevelError, "error sending command %s - %s", resp.cmd, err) c.close() continue } if msgsInFlight == 0 && atomic.LoadInt32(&c.closeFlag) == 1 { c.close() continue } } } exit: c.wg.Done() c.log(LogLevelInfo, "writeLoop exiting") } func (c *Conn) close() { // a "clean" connection close is orchestrated as follows: // // 1. CLOSE cmd sent to nsqd // 2. CLOSE_WAIT response received from nsqd // 3. set c.closeFlag // 4. readLoop() exits // a. if messages-in-flight > 0 delay close() // i. writeLoop() continues receiving on c.msgResponseChan chan // x. when messages-in-flight == 0 call close() // b. else call close() immediately // 5. c.exitChan close // a. writeLoop() exits // i. c.drainReady close // 6a. launch cleanup() goroutine (we're racing with intraprocess // routed messages, see comments below) // a. wait on c.drainReady // b. loop and receive on c.msgResponseChan chan // until messages-in-flight == 0 // i. ensure that readLoop has exited // 6b. launch waitForCleanup() goroutine // b. wait on waitgroup (covers readLoop() and writeLoop() // and cleanup goroutine) // c. underlying TCP connection close // d. trigger Delegate OnClose() // c.stopper.Do(func() { c.log(LogLevelInfo, "beginning close") close(c.exitChan) c.conn.CloseRead() c.wg.Add(1) go c.cleanup() go c.waitForCleanup() }) } func (c *Conn) cleanup() { <-c.drainReady ticker := time.NewTicker(100 * time.Millisecond) lastWarning := time.Now() // writeLoop has exited, drain any remaining in flight messages for { // we're racing with readLoop which potentially has a message // for handling so infinitely loop until messagesInFlight == 0 // and readLoop has exited var msgsInFlight int64 select { case <-c.msgResponseChan: msgsInFlight = atomic.AddInt64(&c.messagesInFlight, -1) case <-ticker.C: msgsInFlight = atomic.LoadInt64(&c.messagesInFlight) } if msgsInFlight > 0 { if time.Now().Sub(lastWarning) > time.Second { c.log(LogLevelWarning, "draining... waiting for %d messages in flight", msgsInFlight) lastWarning = time.Now() } continue } // until the readLoop has exited we cannot be sure that there // still won't be a race if atomic.LoadInt32(&c.readLoopRunning) == 1 { if time.Now().Sub(lastWarning) > time.Second { c.log(LogLevelWarning, "draining... readLoop still running") lastWarning = time.Now() } continue } goto exit } exit: ticker.Stop() c.wg.Done() c.log(LogLevelInfo, "finished draining, cleanup exiting") } func (c *Conn) waitForCleanup() { // this blocks until readLoop and writeLoop // (and cleanup goroutine above) have exited c.wg.Wait() c.conn.CloseWrite() c.log(LogLevelInfo, "clean close complete") c.delegate.OnClose(c) } func (c *Conn) onMessageFinish(m *Message) { c.msgResponseChan <- &msgResponse{msg: m, cmd: Finish(m.ID), success: true} } func (c *Conn) onMessageRequeue(m *Message, delay time.Duration, backoff bool) { if delay == -1 { // linear delay delay = c.config.DefaultRequeueDelay * time.Duration(m.Attempts) // bound the requeueDelay to configured max if delay > c.config.MaxRequeueDelay { delay = c.config.MaxRequeueDelay } } c.msgResponseChan <- &msgResponse{msg: m, cmd: Requeue(m.ID, delay), success: false, backoff: backoff} } func (c *Conn) onMessageTouch(m *Message) { select { case c.cmdChan <- Touch(m.ID): case <-c.exitChan: } } func (c *Conn) log(lvl LogLevel, line string, args ...interface{}) { logger, logLvl, logFmt := c.getLogger() if logger == nil { return } if logLvl > lvl { return } logger.Output(2, fmt.Sprintf("%-4s %s %s", lvl, fmt.Sprintf(logFmt, c.String()), fmt.Sprintf(line, args...))) }