2018-11-07 10:23:13 -08:00

672 lines
19 KiB
Go

package nsq
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"log"
"math"
"math/rand"
"net"
"os"
"reflect"
"strconv"
"strings"
"sync"
"time"
"unsafe"
)
// configHandler is implemented by the handlers registered on a Config. Handlers
// set config defaults (see defaultsHandler) and set config values that come
// from command line arguments or config files.
//
// Config.Set dispatches an option to the first registered handler whose
// HandlesOption returns true; Config.Validate calls Validate on every handler.
type configHandler interface {
// HandlesOption reports whether this handler is responsible for option.
HandlesOption(c *Config, option string) bool
// Set applies value to option on c, returning an error if it cannot be applied.
Set(c *Config, option string, value interface{}) error
// Validate checks the current state of c.
Validate(c *Config) error
}
// defaultsHandler is optionally implemented by a configHandler.
// Config.setDefaults calls SetDefaults on every registered handler that
// implements it and stops at the first error.
type defaultsHandler interface {
SetDefaults(c *Config) error
}
// BackoffStrategy defines a strategy for calculating the duration of time
// a consumer should backoff for a given attempt.
//
// The implementations in this file are ExponentialStrategy and
// FullJitterStrategy.
type BackoffStrategy interface {
// Calculate returns the backoff duration for the given attempt number.
Calculate(attempt int) time.Duration
}
// ExponentialStrategy implements an exponential backoff strategy (default)
type ExponentialStrategy struct {
// cfg supplies BackoffMultiplier to Calculate; it is assigned by setConfig
cfg *Config
}
// Calculate returns the backoff for the given attempt, computed as
// BackoffMultiplier * 2^attempt.
func (s *ExponentialStrategy) Calculate(attempt int) time.Duration {
	factor := time.Duration(math.Pow(2, float64(attempt)))
	return s.cfg.BackoffMultiplier * factor
}
// setConfig stores cfg so that Calculate can read BackoffMultiplier from it.
// structTagsConfig.Set calls it when a backoff strategy is assigned to a Config.
func (s *ExponentialStrategy) setConfig(cfg *Config) {
s.cfg = cfg
}
// FullJitterStrategy implements the "full jitter" backoff described at
// http://www.awsarchitectureblog.com/2015/03/backoff.html
type FullJitterStrategy struct {
// cfg supplies BackoffMultiplier to Calculate; it is assigned by setConfig
cfg *Config
// rngOnce guards the lazy creation of rng on the first call to Calculate
rngOnce sync.Once
// rng is the random source used by Calculate; if it is already non-nil on
// the first call to Calculate it is kept instead of being replaced
rng *rand.Rand
}
// Calculate returns a random duration in the half-open interval
// [0, BackoffMultiplier * 2^attempt).
//
// It returns 0 when that upper bound is not positive, for example when
// BackoffMultiplier is 0 (which its "min" tag permits) or when the
// multiplication overflows. The random number functions of math/rand panic
// for a non-positive bound, so this case must be handled explicitly.
func (s *FullJitterStrategy) Calculate(attempt int) time.Duration {
	// lazily initialize the RNG, keeping one that was assigned beforehand
	s.rngOnce.Do(func() {
		if s.rng != nil {
			return
		}
		s.rng = rand.New(rand.NewSource(time.Now().UnixNano()))
	})

	backoffDuration := s.cfg.BackoffMultiplier *
		time.Duration(math.Pow(2, float64(attempt)))
	if backoffDuration <= 0 {
		return 0
	}

	// Int63n takes an int64, so the nanosecond count is not truncated on
	// platforms where int is 32 bits wide.
	return time.Duration(s.rng.Int63n(int64(backoffDuration)))
}
// setConfig stores cfg so that Calculate can read BackoffMultiplier from it.
// structTagsConfig.Set calls it when a backoff strategy is assigned to a Config.
func (s *FullJitterStrategy) setConfig(cfg *Config) {
s.cfg = cfg
}
// Config is a struct of NSQ options
//
// The only valid way to create a Config is via NewConfig; on a Config built from a struct
// literal, methods such as Set and Validate panic.
// After Config is passed into a high-level type (like Consumer, Producer, etc.) the values are no
// longer mutable (they are copied).
//
// Use Set(option string, value interface{}) as an alternate way to set parameters
//
// The struct tags drive the option handling in this file: `opt` is the option name accepted by
// Set, `default` is applied by NewConfig, and `min`/`max` are enforced by Set and Validate.
type Config struct {
// initialized is set to true by NewConfig and checked by assertInitialized
initialized bool
// used to Initialize, Validate
configHandlers []configHandler
// NOTE(review): DialTimeout is not read in this file; presumably the timeout for
// establishing a connection to nsqd - confirm at the dial site
DialTimeout time.Duration `opt:"dial_timeout" default:"1s"`
// Deadlines for network reads and writes
ReadTimeout time.Duration `opt:"read_timeout" min:"100ms" max:"5m" default:"60s"`
WriteTimeout time.Duration `opt:"write_timeout" min:"100ms" max:"5m" default:"1s"`
// LocalAddr is the local address to use when dialing an nsqd.
// If empty, a local address is automatically chosen.
// When set from a string, the value is resolved as a TCP address.
LocalAddr net.Addr `opt:"local_addr"`
// Duration between polling lookupd for new producers, and fractional jitter to add to
// the lookupd poll loop. This helps evenly distribute requests even if multiple consumers
// restart at the same time
//
// NOTE: when not using nsqlookupd, LookupdPollInterval represents the duration of time between
// reconnection attempts
LookupdPollInterval time.Duration `opt:"lookupd_poll_interval" min:"10ms" max:"5m" default:"60s"`
LookupdPollJitter float64 `opt:"lookupd_poll_jitter" min:"0" max:"1" default:"0.3"`
// Maximum duration when REQueueing (for doubling of deferred requeue)
MaxRequeueDelay time.Duration `opt:"max_requeue_delay" min:"0" max:"60m" default:"15m"`
DefaultRequeueDelay time.Duration `opt:"default_requeue_delay" min:"0" max:"60m" default:"90s"`
// Backoff strategy, defaults to exponential backoff. Overwrite this to define alternative backoff algorithms.
// When set from a string, the accepted names are "exponential" (or "") and "full_jitter".
BackoffStrategy BackoffStrategy `opt:"backoff_strategy" default:"exponential"`
// Maximum amount of time to backoff when processing fails (0 == no backoff)
MaxBackoffDuration time.Duration `opt:"max_backoff_duration" min:"0" max:"60m" default:"2m"`
// Unit of time for calculating consumer backoff
BackoffMultiplier time.Duration `opt:"backoff_multiplier" min:"0" max:"60m" default:"1s"`
// Maximum number of times this consumer will attempt to process a message before giving up
MaxAttempts uint16 `opt:"max_attempts" min:"0" max:"65535" default:"5"`
// Duration to wait for a message from an nsqd when in a state where RDY
// counts are re-distributed (e.g. max_in_flight < num_producers)
LowRdyIdleTimeout time.Duration `opt:"low_rdy_idle_timeout" min:"1s" max:"5m" default:"10s"`
// Duration to wait until redistributing RDY for an nsqd regardless of LowRdyIdleTimeout
LowRdyTimeout time.Duration `opt:"low_rdy_timeout" min:"1s" max:"5m" default:"30s"`
// Duration between redistributing max-in-flight to connections
RDYRedistributeInterval time.Duration `opt:"rdy_redistribute_interval" min:"1ms" max:"5s" default:"5s"`
// Identifiers sent to nsqd representing this client
// UserAgent is in the spirit of HTTP (default: "<client_library_name>/<version>")
// Hostname defaults to the value of os.Hostname
ClientID string `opt:"client_id"` // (defaults: short hostname)
Hostname string `opt:"hostname"`
UserAgent string `opt:"user_agent"`
// Duration of time between heartbeats. Validate rejects a HeartbeatInterval
// that is greater than ReadTimeout
HeartbeatInterval time.Duration `opt:"heartbeat_interval" default:"30s"`
// Integer percentage to sample the channel (requires nsqd 0.2.25+)
SampleRate int32 `opt:"sample_rate" min:"0" max:"99"`
// To set TLS config, use the following options:
//
// tls_v1 - Bool enable TLS negotiation
// tls_root_ca_file - String path to file containing root CA
// tls_insecure_skip_verify - Bool; when true, verification of server certificates is skipped
// (it sets InsecureSkipVerify on the tls.Config)
// tls_cert - String path to file containing public key for certificate
// tls_key - String path to file containing private key for certificate
// tls_min_version - String indicating the minimum version of tls acceptable ('ssl3.0', 'tls1.0', 'tls1.1', 'tls1.2')
//
TlsV1 bool `opt:"tls_v1"`
TlsConfig *tls.Config `opt:"tls_config"`
// Compression Settings
Deflate bool `opt:"deflate"`
DeflateLevel int `opt:"deflate_level" min:"1" max:"9" default:"6"`
Snappy bool `opt:"snappy"`
// Size of the buffer (in bytes) used by nsqd for buffering writes to this connection
OutputBufferSize int64 `opt:"output_buffer_size" default:"16384"`
// Timeout used by nsqd before flushing buffered writes (set to 0 to disable).
//
// WARNING: configuring clients with an extremely low
// (< 25ms) output_buffer_timeout has a significant effect
// on nsqd CPU usage (particularly with > 50 clients connected).
OutputBufferTimeout time.Duration `opt:"output_buffer_timeout" default:"250ms"`
// Maximum number of messages to allow in flight (concurrency knob)
MaxInFlight int `opt:"max_in_flight" min:"0" default:"1"`
// The server-side message timeout for messages delivered to this client
MsgTimeout time.Duration `opt:"msg_timeout" min:"0"`
// secret for nsqd authentication (requires nsqd 0.2.29+)
AuthSecret string `opt:"auth_secret"`
}
// NewConfig builds a Config populated with the default nsq settings.
//
// Config structs must always be created through this function. Afterwards,
// values can be assigned to the exported fields directly or via Config.Set().
//
// It panics if one of the registered handlers fails to apply its defaults.
func NewConfig() *Config {
	cfg := new(Config)
	cfg.initialized = true
	cfg.configHandlers = []configHandler{
		&structTagsConfig{},
		&tlsConfig{},
	}

	err := cfg.setDefaults()
	if err != nil {
		panic(err.Error())
	}
	return cfg
}
// Set assigns value to the configuration option named by option.
//
// Dashes in the option name are treated as underscores. The value is coerced
// into the type of the target option, based on the underlying type of the
// value that was passed in.
//
// Options holding a time.Duration accept:
//
//	"1000ms"               (a string parsed by time.ParseDuration())
//	1000                   (an integer interpreted as milliseconds)
//	1000*time.Millisecond  (a literal time.Duration value)
//
// Options holding a bool accept:
//
//	"true"  (a string parsed by strconv.ParseBool())
//	true    (a boolean)
//	1       (an integer, where 0 is false and any other value is true)
//
// An error is returned for an invalid option or value. Set panics when the
// Config was not created with NewConfig.
func (c *Config) Set(option string, value interface{}) error {
	c.assertInitialized()

	normalized := strings.Replace(option, "-", "_", -1)
	for _, handler := range c.configHandlers {
		if !handler.HandlesOption(c, normalized) {
			continue
		}
		// the first handler claiming the option decides the outcome
		return handler.Set(c, normalized, value)
	}

	return fmt.Errorf("invalid option %s", normalized)
}
// assertInitialized panics unless the Config was created by NewConfig, which
// is the only place that sets the initialized flag.
func (c *Config) assertInitialized() {
	if c.initialized {
		return
	}
	panic("Config{} must be created with NewConfig()")
}
// Validate runs the Validate method of every registered handler, in
// registration order, and returns the first error encountered. The struct-tag
// handler checks that all values are within their specified min/max ranges.
//
// It panics when the Config was not created with NewConfig.
func (c *Config) Validate() error {
	c.assertInitialized()

	for i := range c.configHandlers {
		err := c.configHandlers[i].Validate(c)
		if err != nil {
			return err
		}
	}
	return nil
}
// setDefaults asks every registered handler that implements defaultsHandler
// to apply its defaults, stopping at the first error.
func (c *Config) setDefaults() error {
	for _, handler := range c.configHandlers {
		withDefaults, implemented := handler.(defaultsHandler)
		if !implemented {
			continue
		}
		if err := withDefaults.SetDefaults(c); err != nil {
			return err
		}
	}
	return nil
}
// structTagsConfig is the configHandler for options that are declared through
// the `opt`, `min`, `max` and `default` struct tags on the fields of Config.
type structTagsConfig struct{}
// HandlesOption reports whether option is the `opt` struct tag of one of the
// fields of Config.
//
// The empty option is never handled. Fields without an `opt` tag (such as the
// private bookkeeping fields) report an empty tag value, so without this
// guard an empty option name would be treated as a match for them.
func (h *structTagsConfig) HandlesOption(c *Config, option string) bool {
	if option == "" {
		return false
	}

	typ := reflect.TypeOf(c).Elem()
	for i := 0; i < typ.NumField(); i++ {
		if typ.Field(i).Tag.Get("opt") == option {
			return true
		}
	}
	return false
}
// Set assigns value to the Config field whose `opt` struct tag equals option.
//
// The value is coerced to the field's type and checked against the field's
// `min` and `max` tags (when present) before it is stored. When the field
// holds a BackoffStrategy that has a setConfig method, the strategy is handed
// the Config so it can read settings from it.
//
// An error is returned when the option is unknown or empty, when the value
// cannot be coerced, when it is out of range, or when a `min`/`max` tag itself
// cannot be coerced to the field's type.
func (h *structTagsConfig) Set(c *Config, option string, value interface{}) error {
	// Untagged fields report an empty `opt`; never let an empty option name
	// address them, otherwise private fields could be overwritten.
	if option == "" {
		return fmt.Errorf("unknown option %s", option)
	}

	val := reflect.ValueOf(c).Elem()
	typ := val.Type()
	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)
		if field.Tag.Get("opt") != option {
			continue
		}

		coercedVal, err := coerce(value, field.Type)
		if err != nil {
			return fmt.Errorf("failed to coerce option %s (%v) - %s",
				option, value, err)
		}

		if minTag := field.Tag.Get("min"); minTag != "" {
			coercedMinVal, err := coerce(minTag, field.Type)
			if err != nil {
				return fmt.Errorf("invalid min tag %q for option %s - %s",
					minTag, option, err)
			}
			if valueCompare(coercedVal, coercedMinVal) == -1 {
				return fmt.Errorf("invalid %s ! %v < %v",
					option, coercedVal.Interface(), coercedMinVal.Interface())
			}
		}

		if maxTag := field.Tag.Get("max"); maxTag != "" {
			coercedMaxVal, err := coerce(maxTag, field.Type)
			if err != nil {
				return fmt.Errorf("invalid max tag %q for option %s - %s",
					maxTag, option, err)
			}
			if valueCompare(coercedVal, coercedMaxVal) == 1 {
				return fmt.Errorf("invalid %s ! %v > %v",
					option, coercedVal.Interface(), coercedMaxVal.Interface())
			}
		}

		// built-in strategies need access to the Config they belong to
		if coercedVal.Type().String() == "nsq.BackoffStrategy" {
			v := coercedVal.Interface().(BackoffStrategy)
			if v, ok := v.(interface {
				setConfig(*Config)
			}); ok {
				v.setConfig(c)
			}
		}

		unsafeValueOf(val.Field(i)).Set(coercedVal)
		return nil
	}

	return fmt.Errorf("unknown option %s", option)
}
// SetDefaults applies the `default` struct tag of every Config field that
// also carries an `opt` tag by routing it through Config.Set, then fills in
// ClientID (the hostname up to the first "."), Hostname and UserAgent.
//
// It returns the first error reported by Config.Set. If the hostname cannot
// be determined, the failure is logged with log.Fatalf.
func (h *structTagsConfig) SetDefaults(c *Config) error {
	cfgType := reflect.ValueOf(c).Elem().Type()
	for idx := 0; idx < cfgType.NumField(); idx++ {
		tag := cfgType.Field(idx).Tag
		name, fallback := tag.Get("opt"), tag.Get("default")
		if name == "" || fallback == "" {
			continue
		}
		if err := c.Set(name, fallback); err != nil {
			return err
		}
	}

	host, err := os.Hostname()
	if err != nil {
		log.Fatalf("ERROR: unable to get hostname %s", err.Error())
	}

	c.Hostname = host
	c.ClientID = strings.Split(host, ".")[0]
	c.UserAgent = fmt.Sprintf("go-nsq/%s", VERSION)
	return nil
}
// Validate checks every Config field that carries a `min` and/or `max` struct
// tag against those bounds, and then checks that HeartbeatInterval does not
// exceed ReadTimeout. The first violation is returned as an error.
func (h *structTagsConfig) Validate(c *Config) error {
	cfgVal := reflect.ValueOf(c).Elem()
	cfgType := cfgVal.Type()

	for idx := 0; idx < cfgType.NumField(); idx++ {
		sf := cfgType.Field(idx)
		lower, upper := sf.Tag.Get("min"), sf.Tag.Get("max")
		if lower == "" && upper == "" {
			continue
		}

		current := cfgVal.Field(idx)

		if lower != "" {
			// coercion errors are ignored here, as they always have been
			lowerVal, _ := coerce(lower, sf.Type)
			if valueCompare(current, lowerVal) == -1 {
				return fmt.Errorf("invalid %s ! %v < %v",
					sf.Name, current.Interface(), lowerVal.Interface())
			}
		}

		if upper != "" {
			upperVal, _ := coerce(upper, sf.Type)
			if valueCompare(current, upperVal) == 1 {
				return fmt.Errorf("invalid %s ! %v > %v",
					sf.Name, current.Interface(), upperVal.Interface())
			}
		}
	}

	if c.HeartbeatInterval <= c.ReadTimeout {
		return nil
	}
	return fmt.Errorf("HeartbeatInterval %v must be less than ReadTimeout %v", c.HeartbeatInterval, c.ReadTimeout)
}
// tlsConfig is the configHandler that parses the higher order TLS settings
// (tls_root_ca_file, tls_insecure_skip_verify, tls_cert, tls_key and
// tls_min_version) into Config.TlsConfig.
type tlsConfig struct {
// certFile and keyFile remember the paths given through the tls_cert and
// tls_key options; the key pair is loaded once both are known
certFile string
keyFile string
}
// HandlesOption reports whether option is one of the higher order TLS
// options that this handler knows how to apply.
func (t *tlsConfig) HandlesOption(c *Config, option string) bool {
	return option == "tls_root_ca_file" ||
		option == "tls_insecure_skip_verify" ||
		option == "tls_cert" ||
		option == "tls_key" ||
		option == "tls_min_version"
}
// Set applies one of the higher order TLS options to c.TlsConfig, creating
// the tls.Config (TLS 1.0 to TLS 1.2) first if c has none yet.
//
//   - tls_cert / tls_key: value must be a string path. The key pair is loaded
//     once both paths are known and no certificate has been loaded yet.
//   - tls_root_ca_file: value must be a string path to a PEM file whose
//     certificates become the root CA pool.
//   - tls_insecure_skip_verify: value is coerced to a bool and stored in
//     InsecureSkipVerify.
//   - tls_min_version: value must be one of "ssl3.0", "tls1.0", "tls1.1" or
//     "tls1.2".
//
// An error is returned for values of the wrong type, unreadable or invalid
// files, unknown TLS versions and unknown options.
func (t *tlsConfig) Set(c *Config, option string, value interface{}) error {
	if c.TlsConfig == nil {
		c.TlsConfig = &tls.Config{
			MinVersion: tls.VersionTLS10,
			MaxVersion: tls.VersionTLS12, // enable TLS_FALLBACK_SCSV prior to Go 1.5: https://go-review.googlesource.com/#/c/1776/
		}
	}

	switch option {
	case "tls_cert", "tls_key":
		// a checked assertion turns a wrongly typed value into an error
		// instead of a panic
		filename, ok := value.(string)
		if !ok {
			return fmt.Errorf("ERROR: %v is not a string", value)
		}
		if option == "tls_cert" {
			t.certFile = filename
		} else {
			t.keyFile = filename
		}
		if t.certFile != "" && t.keyFile != "" && len(c.TlsConfig.Certificates) == 0 {
			cert, err := tls.LoadX509KeyPair(t.certFile, t.keyFile)
			if err != nil {
				return err
			}
			c.TlsConfig.Certificates = []tls.Certificate{cert}
		}
		return nil
	case "tls_root_ca_file":
		filename, ok := value.(string)
		if !ok {
			return fmt.Errorf("ERROR: %v is not a string", value)
		}
		tlsCertPool := x509.NewCertPool()
		caCertFile, err := ioutil.ReadFile(filename)
		if err != nil {
			return fmt.Errorf("ERROR: failed to read custom Certificate Authority file %s", err)
		}
		if !tlsCertPool.AppendCertsFromPEM(caCertFile) {
			return fmt.Errorf("ERROR: failed to append certificates from Certificate Authority file")
		}
		c.TlsConfig.RootCAs = tlsCertPool
		return nil
	case "tls_insecure_skip_verify":
		coercedVal, err := coerce(value, reflect.TypeOf(c.TlsConfig.InsecureSkipVerify))
		if err != nil {
			return fmt.Errorf("failed to coerce option %s (%v) - %s",
				option, value, err)
		}
		// InsecureSkipVerify is exported, so a plain assignment is enough
		c.TlsConfig.InsecureSkipVerify = coercedVal.Bool()
		return nil
	case "tls_min_version":
		version, ok := value.(string)
		if !ok {
			return fmt.Errorf("ERROR: %v is not a string", value)
		}
		switch version {
		case "ssl3.0":
			c.TlsConfig.MinVersion = tls.VersionSSL30
		case "tls1.0":
			c.TlsConfig.MinVersion = tls.VersionTLS10
		case "tls1.1":
			c.TlsConfig.MinVersion = tls.VersionTLS11
		case "tls1.2":
			c.TlsConfig.MinVersion = tls.VersionTLS12
		default:
			return fmt.Errorf("ERROR: %v is not a tls version", value)
		}
		return nil
	}

	return fmt.Errorf("unknown option %s", option)
}
// Validate performs no checks and always returns nil; it exists so that
// tlsConfig satisfies the configHandler interface.
func (t *tlsConfig) Validate(c *Config) error {
return nil
}
// unsafeValueOf returns a reflect.Value that refers to the same memory as val
// but is built from val's address rather than by field access.
//
// Config contains private fields, and a reflect.Value obtained by walking to
// such a field cannot be used with Set. Re-creating the Value from the raw
// address with reflect.NewAt sidesteps that restriction. val must be
// addressable.
func unsafeValueOf(val reflect.Value) reflect.Value {
	return reflect.NewAt(val.Type(), unsafe.Pointer(val.UnsafeAddr())).Elem()
}
// valueCompare compares two numeric reflect values and returns 1 when
// v1 > v2, -1 when v1 < v2 and 0 when they are equal.
//
// The comparison is selected by the Kind of v1, so every signed integer,
// unsigned integer and float width is supported, as are named types built on
// them (for example time.Duration). v2 must hold a value of the same class.
//
// It panics when v1 is not a numeric value.
func valueCompare(v1 reflect.Value, v2 reflect.Value) int {
	switch v1.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		a, b := v1.Int(), v2.Int()
		switch {
		case a > b:
			return 1
		case a < b:
			return -1
		}
		return 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		a, b := v1.Uint(), v2.Uint()
		switch {
		case a > b:
			return 1
		case a < b:
			return -1
		}
		return 0
	case reflect.Float32, reflect.Float64:
		a, b := v1.Float(), v2.Float()
		switch {
		case a > b:
			return 1
		case a < b:
			return -1
		}
		return 0
	}
	panic(fmt.Sprintf("valueCompare: unsupported kind %s", v1.Kind()))
}
// coerce converts v into a reflect.Value of type typ.
//
// For pointer types v is passed through unchanged: nil becomes the nil
// pointer of typ, and any value that is not assignable to typ is rejected
// with an error. For all other types the conversion is selected by the name
// of typ; strings, integers, floats, bools, time.Duration, net.Addr and
// nsq.BackoffStrategy are supported.
//
// On error the returned Value is the zero value of typ and must not be used
// as a result. Returning early here matters: the interface-typed coercions
// yield nil on failure, and converting nil further would panic inside
// reflect instead of reporting the error.
func coerce(v interface{}, typ reflect.Type) (reflect.Value, error) {
	if typ.Kind() == reflect.Ptr {
		if v == nil {
			return reflect.Zero(typ), nil
		}
		rv := reflect.ValueOf(v)
		if !rv.Type().AssignableTo(typ) {
			return reflect.Zero(typ), fmt.Errorf("invalid value type %T (expected %s)", v, typ.String())
		}
		return rv, nil
	}

	var (
		coerced interface{}
		err     error
	)
	switch typ.String() {
	case "string":
		coerced, err = coerceString(v)
	case "int", "int16", "int32", "int64":
		coerced, err = coerceInt64(v)
	case "uint", "uint16", "uint32", "uint64":
		coerced, err = coerceUint64(v)
	case "float32", "float64":
		coerced, err = coerceFloat64(v)
	case "bool":
		coerced, err = coerceBool(v)
	case "time.Duration":
		coerced, err = coerceDuration(v)
	case "net.Addr":
		coerced, err = coerceAddr(v)
	case "nsq.BackoffStrategy":
		coerced, err = coerceBackoffStrategy(v)
	default:
		err = fmt.Errorf("invalid type %s", typ.String())
	}
	if err != nil {
		return reflect.Zero(typ), err
	}
	return valueTypeCoerce(coerced, typ), nil
}
// valueTypeCoerce returns v as a reflect.Value whose type is exactly typ.
//
// A value that already has type typ is returned as is. Integer, unsigned and
// float targets (of any width) are filled through SetInt, SetUint and
// SetFloat, so v only has to belong to the same numeric class; a value too
// large for a narrower target is truncated by reflect. Every other target is
// filled with Set, which requires v to be assignable to typ.
//
// A nil v yields the zero value of typ rather than a panic.
func valueTypeCoerce(v interface{}, typ reflect.Type) reflect.Value {
	tval := reflect.New(typ).Elem()

	val := reflect.ValueOf(v)
	if !val.IsValid() {
		// reflect.ValueOf(nil) is the invalid Value, which Set rejects
		return tval
	}
	if val.Type() == typ {
		return val
	}

	switch typ.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		tval.SetInt(val.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		tval.SetUint(val.Uint())
	case reflect.Float32, reflect.Float64:
		tval.SetFloat(val.Float())
	default:
		tval.Set(val)
	}
	return tval
}
// coerceString renders v as a string. The returned error is always nil.
//
// Strings are returned unchanged, integers of every width are printed in
// decimal, floats are printed with %f and bools as "true" or "false". Any
// other value is formatted with the %s verb.
func coerceString(v interface{}) (string, error) {
	switch v := v.(type) {
	case string:
		return v, nil
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return fmt.Sprintf("%d", v), nil
	case float32, float64:
		return fmt.Sprintf("%f", v), nil
	case bool:
		// %s would render a bool as "%!s(bool=true)"
		return strconv.FormatBool(v), nil
	}
	return fmt.Sprintf("%s", v), nil
}
// coerceDuration converts v into a time.Duration.
//
// Strings are parsed with time.ParseDuration, signed and unsigned integers of
// every width are interpreted as a number of milliseconds, and a
// time.Duration is returned unchanged. Any other type yields an error.
func coerceDuration(v interface{}) (time.Duration, error) {
	switch v := v.(type) {
	case string:
		return time.ParseDuration(v)
	case int, int8, int16, int32, int64:
		// treat like ms
		return time.Duration(reflect.ValueOf(v).Int()) * time.Millisecond, nil
	case uint, uint8, uint16, uint32, uint64:
		// treat like ms
		return time.Duration(reflect.ValueOf(v).Uint()) * time.Millisecond, nil
	case time.Duration:
		return v, nil
	}
	return 0, errors.New("invalid value type")
}
// coerceAddr converts v into a net.Addr.
//
// A string is resolved as a TCP address with net.ResolveTCPAddr, and a
// net.Addr is returned unchanged. Any other type yields an error.
//
// On failure the returned net.Addr is a true nil interface. Returning the
// *net.TCPAddr result of a failed resolve directly would wrap a nil pointer
// in a non-nil interface value.
func coerceAddr(v interface{}) (net.Addr, error) {
	switch v := v.(type) {
	case string:
		addr, err := net.ResolveTCPAddr("tcp", v)
		if err != nil {
			return nil, err
		}
		return addr, nil
	case net.Addr:
		return v, nil
	}
	return nil, errors.New("invalid value type")
}
// coerceBackoffStrategy converts v into a BackoffStrategy.
//
// The strings "" and "exponential" select a new ExponentialStrategy and
// "full_jitter" selects a new FullJitterStrategy. A value that already
// implements BackoffStrategy is returned unchanged. Everything else,
// including any other string, yields an error.
func coerceBackoffStrategy(v interface{}) (BackoffStrategy, error) {
	if name, isString := v.(string); isString {
		switch name {
		case "", "exponential":
			return &ExponentialStrategy{}, nil
		case "full_jitter":
			return &FullJitterStrategy{}, nil
		}
		return nil, errors.New("invalid value type")
	}

	if strategy, ok := v.(BackoffStrategy); ok {
		return strategy, nil
	}
	return nil, errors.New("invalid value type")
}
// coerceBool converts v into a bool.
//
// A bool is returned unchanged, a string is parsed with strconv.ParseBool,
// and signed or unsigned integers of every width are true when non-zero.
// Any other type yields an error.
func coerceBool(v interface{}) (bool, error) {
	switch v := v.(type) {
	case bool:
		return v, nil
	case string:
		return strconv.ParseBool(v)
	case int, int8, int16, int32, int64:
		return reflect.ValueOf(v).Int() != 0, nil
	case uint, uint8, uint16, uint32, uint64:
		return reflect.ValueOf(v).Uint() != 0, nil
	}
	return false, errors.New("invalid value type")
}
// coerceFloat64 converts v into a float64.
//
// A string is parsed with strconv.ParseFloat, signed and unsigned integers of
// every width and float32 values are converted, and a float64 is returned
// unchanged. Any other type yields an error.
func coerceFloat64(v interface{}) (float64, error) {
	switch v := v.(type) {
	case string:
		return strconv.ParseFloat(v, 64)
	case int, int8, int16, int32, int64:
		return float64(reflect.ValueOf(v).Int()), nil
	case uint, uint8, uint16, uint32, uint64:
		return float64(reflect.ValueOf(v).Uint()), nil
	case float32:
		return float64(v), nil
	case float64:
		return v, nil
	}
	return 0, errors.New("invalid value type")
}
// coerceInt64 converts v into an int64.
//
// A string is parsed as a base-10 integer with strconv.ParseInt, and signed
// or unsigned integers of every width are converted. An unsigned value larger
// than math.MaxInt64 is rejected with an error instead of silently wrapping
// around to a negative number. Any other type yields an error.
func coerceInt64(v interface{}) (int64, error) {
	switch v := v.(type) {
	case string:
		return strconv.ParseInt(v, 10, 64)
	case int, int8, int16, int32, int64:
		return reflect.ValueOf(v).Int(), nil
	case uint, uint8, uint16, uint32, uint64:
		u := reflect.ValueOf(v).Uint()
		if u > math.MaxInt64 {
			return 0, fmt.Errorf("value %d overflows int64", u)
		}
		return int64(u), nil
	}
	return 0, errors.New("invalid value type")
}
// coerceUint64 converts v into a uint64.
//
// A string is parsed as a base-10 integer with strconv.ParseUint, and signed
// or unsigned integers of every width are converted. A negative value is
// rejected with an error instead of silently wrapping around to a very large
// number. Any other type yields an error.
func coerceUint64(v interface{}) (uint64, error) {
	switch v := v.(type) {
	case string:
		return strconv.ParseUint(v, 10, 64)
	case int, int8, int16, int32, int64:
		i := reflect.ValueOf(v).Int()
		if i < 0 {
			return 0, fmt.Errorf("negative value %d cannot be converted to an unsigned integer", i)
		}
		return uint64(i), nil
	case uint, uint8, uint16, uint32, uint64:
		return reflect.ValueOf(v).Uint(), nil
	}
	return 0, errors.New("invalid value type")
}