Persist offline mqtt events in the queueDir and replay (#7037)

This commit is contained in:
Praveen raj Mani
2019-01-14 12:39:00 +05:30
committed by Harshavardhana
parent 8757c963ba
commit 6571641735
27 changed files with 1507 additions and 265 deletions

View File

@@ -1,3 +1,7 @@
[![GoDoc](https://godoc.org/github.com/eclipse/paho.mqtt.golang?status.svg)](https://godoc.org/github.com/eclipse/paho.mqtt.golang)
[![Go Report Card](https://goreportcard.com/badge/github.com/eclipse/paho.mqtt.golang)](https://goreportcard.com/report/github.com/eclipse/paho.mqtt.golang)
Eclipse Paho MQTT Go client
===========================
@@ -18,11 +22,12 @@ This client is designed to work with the standard Go tools, so installation is a
go get github.com/eclipse/paho.mqtt.golang
```
The client depends on Google's [websockets](https://godoc.org/golang.org/x/net/websocket) package,
also easily installed with the command:
The client depends on Google's [websockets](https://godoc.org/golang.org/x/net/websocket) and [proxy](https://godoc.org/golang.org/x/net/proxy) package,
also easily installed with the commands:
```
go get golang.org/x/net/websocket
go get golang.org/x/net/proxy
```

View File

@@ -12,6 +12,8 @@
* Mike Robertson
*/
// Portions copyright © 2018 TIBCO Software Inc.
// Package mqtt provides an MQTT v3.1.1 client library.
package mqtt
@@ -19,16 +21,16 @@ import (
"errors"
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/eclipse/paho.mqtt.golang/packets"
)
type connStatus uint
const (
disconnected connStatus = iota
disconnected uint32 = iota
connecting
reconnecting
connected
@@ -53,18 +55,50 @@ const (
// Numerous connection options may be specified by configuring a
// and then supplying a ClientOptions type.
type Client interface {
// IsConnected returns a bool signifying whether
// the client is connected or not.
IsConnected() bool
// IsConnectionOpen return a bool signifying wether the client has an active
// connection to mqtt broker, i.e not in disconnected or reconnect mode
IsConnectionOpen() bool
// Connect will create a connection to the message broker, by default
// it will attempt to connect at v3.1.1 and auto retry at v3.1 if that
// fails
Connect() Token
// Disconnect will end the connection with the server, but not before waiting
// the specified number of milliseconds to wait for existing work to be
// completed.
Disconnect(quiesce uint)
// Publish will publish a message with the specified QoS and content
// to the specified topic.
// Returns a token to track delivery of the message to the broker
Publish(topic string, qos byte, retained bool, payload interface{}) Token
// Subscribe starts a new subscription. Provide a MessageHandler to be executed when
// a message is published on the topic provided, or nil for the default handler
Subscribe(topic string, qos byte, callback MessageHandler) Token
// SubscribeMultiple starts a new subscription for multiple topics. Provide a MessageHandler to
// be executed when a message is published on one of the topics provided, or nil for the
// default handler
SubscribeMultiple(filters map[string]byte, callback MessageHandler) Token
// Unsubscribe will end the subscription from each of the topics provided.
// Messages published to those topics from other clients will no longer be
// received.
Unsubscribe(topics ...string) Token
// AddRoute allows you to add a handler for messages on a specific topic
// without making a subscription. For example having a different handler
// for parts of a wildcard subscription
AddRoute(topic string, callback MessageHandler)
// OptionsReader returns a ClientOptionsReader which is a copy of the clientoptions
// in use by the client.
OptionsReader() ClientOptionsReader
}
// client implements the Client interface
type client struct {
lastSent int64
lastReceived int64
pingOutstanding int32
status uint32
sync.RWMutex
messageIds
conn net.Conn
@@ -78,10 +112,6 @@ type client struct {
stop chan struct{}
persist Store
options ClientOptions
pingResp chan struct{}
packetResp chan struct{}
keepaliveReset chan struct{}
status connStatus
workers sync.WaitGroup
}
@@ -99,21 +129,26 @@ func NewClient(o *ClientOptions) Client {
switch c.options.ProtocolVersion {
case 3, 4:
c.options.protocolVersionExplicit = true
case 0x83, 0x84:
c.options.protocolVersionExplicit = true
default:
c.options.ProtocolVersion = 4
c.options.protocolVersionExplicit = false
}
c.persist = c.options.Store
c.status = disconnected
c.messageIds = messageIds{index: make(map[uint16]Token)}
c.messageIds = messageIds{index: make(map[uint16]tokenCompletor)}
c.msgRouter, c.stopRouter = newRouter()
c.msgRouter.setDefaultHandler(c.options.DefaultPublishHander)
c.msgRouter.setDefaultHandler(c.options.DefaultPublishHandler)
if !c.options.AutoReconnect {
c.options.MessageChannelDepth = 0
}
return c
}
// AddRoute allows you to add a handler for messages on a specific topic
// without making a subscription. For example having a different handler
// for parts of a wildcard subscription
func (c *client) AddRoute(topic string, callback MessageHandler) {
if callback != nil {
c.msgRouter.addRoute(topic, callback)
@@ -125,54 +160,81 @@ func (c *client) AddRoute(topic string, callback MessageHandler) {
func (c *client) IsConnected() bool {
c.RLock()
defer c.RUnlock()
status := atomic.LoadUint32(&c.status)
switch {
case c.status == connected:
case status == connected:
return true
case c.options.AutoReconnect && c.status > disconnected:
case c.options.AutoReconnect && status > connecting:
return true
default:
return false
}
}
func (c *client) connectionStatus() connStatus {
// IsConnectionOpen return a bool signifying whether the client has an active
// connection to mqtt broker, i.e not in disconnected or reconnect mode
func (c *client) IsConnectionOpen() bool {
c.RLock()
defer c.RUnlock()
return c.status
status := atomic.LoadUint32(&c.status)
switch {
case status == connected:
return true
default:
return false
}
}
func (c *client) setConnected(status connStatus) {
func (c *client) connectionStatus() uint32 {
c.RLock()
defer c.RUnlock()
status := atomic.LoadUint32(&c.status)
return status
}
func (c *client) setConnected(status uint32) {
c.Lock()
defer c.Unlock()
c.status = status
atomic.StoreUint32(&c.status, uint32(status))
}
//ErrNotConnected is the error returned from function calls that are
//made when the client is not connected to a broker
var ErrNotConnected = errors.New("Not Connected")
// Connect will create a connection to the message broker
// If clean session is false, then a slice will
// be returned containing Receipts for all messages
// that were in-flight at the last disconnect.
// If clean session is true, then any existing client
// state will be removed.
// Connect will create a connection to the message broker, by default
// it will attempt to connect at v3.1.1 and auto retry at v3.1 if that
// fails
func (c *client) Connect() Token {
var err error
t := newToken(packets.Connect).(*ConnectToken)
DEBUG.Println(CLI, "Connect()")
c.obound = make(chan *PacketAndToken, c.options.MessageChannelDepth)
c.oboundP = make(chan *PacketAndToken, c.options.MessageChannelDepth)
c.ibound = make(chan packets.ControlPacket)
go func() {
c.persist.Open()
c.setConnected(connecting)
c.errors = make(chan error, 1)
c.stop = make(chan struct{})
var rc byte
cm := newConnectMsgFromOptions(&c.options)
protocolVersion := c.options.ProtocolVersion
if len(c.options.Servers) == 0 {
t.setError(fmt.Errorf("No servers defined to connect to"))
return
}
for _, broker := range c.options.Servers {
cm := newConnectMsgFromOptions(&c.options, broker)
c.options.ProtocolVersion = protocolVersion
CONN:
DEBUG.Println(CLI, "about to write new connect msg")
c.conn, err = openConnection(broker, &c.options.TLSConfig, c.options.ConnectTimeout)
c.conn, err = openConnection(broker, c.options.TLSConfig, c.options.ConnectTimeout, c.options.HTTPHeaders)
if err == nil {
DEBUG.Println(CLI, "socket connected to broker")
switch c.options.ProtocolVersion {
@@ -180,6 +242,14 @@ func (c *client) Connect() Token {
DEBUG.Println(CLI, "Using MQTT 3.1 protocol")
cm.ProtocolName = "MQIsdp"
cm.ProtocolVersion = 3
case 0x83:
DEBUG.Println(CLI, "Using MQTT 3.1b protocol")
cm.ProtocolName = "MQIsdp"
cm.ProtocolVersion = 0x83
case 0x84:
DEBUG.Println(CLI, "Using MQTT 3.1.1b protocol")
cm.ProtocolName = "MQTT"
cm.ProtocolVersion = 0x84
default:
DEBUG.Println(CLI, "Using MQTT 3.1.1 protocol")
c.options.ProtocolVersion = 4
@@ -188,7 +258,7 @@ func (c *client) Connect() Token {
}
cm.Write(c.conn)
rc = c.connect()
rc, t.sessionPresent = c.connect()
if rc != packets.Accepted {
if c.conn != nil {
c.conn.Close()
@@ -215,32 +285,27 @@ func (c *client) Connect() Token {
if c.conn == nil {
ERROR.Println(CLI, "Failed to connect to a broker")
t.returnCode = rc
if rc != packets.ErrNetworkError {
t.err = packets.ConnErrors[rc]
} else {
t.err = fmt.Errorf("%s : %s", packets.ConnErrors[rc], err)
}
c.setConnected(disconnected)
c.persist.Close()
t.flowComplete()
t.returnCode = rc
if rc != packets.ErrNetworkError {
t.setError(packets.ConnErrors[rc])
} else {
t.setError(fmt.Errorf("%s : %s", packets.ConnErrors[rc], err))
}
return
}
c.options.protocolVersionExplicit = true
if c.options.KeepAlive != 0 {
atomic.StoreInt32(&c.pingOutstanding, 0)
atomic.StoreInt64(&c.lastReceived, time.Now().Unix())
atomic.StoreInt64(&c.lastSent, time.Now().Unix())
c.workers.Add(1)
go keepalive(c)
}
c.obound = make(chan *PacketAndToken, c.options.MessageChannelDepth)
c.oboundP = make(chan *PacketAndToken, c.options.MessageChannelDepth)
c.ibound = make(chan packets.ControlPacket)
c.errors = make(chan error, 1)
c.stop = make(chan struct{})
c.pingResp = make(chan struct{}, 1)
c.packetResp = make(chan struct{}, 1)
c.keepaliveReset = make(chan struct{}, 1)
c.incomingPubChan = make(chan *packets.PublishPacket, c.options.MessageChannelDepth)
c.msgRouter.matchAndDispatch(c.incomingPubChan, c.options.Order, c)
@@ -250,22 +315,19 @@ func (c *client) Connect() Token {
go c.options.OnConnect(c)
}
// Take care of any messages in the store
//var leftovers []Receipt
if c.options.CleanSession == false {
//leftovers = c.resume()
} else {
c.persist.Reset()
}
c.workers.Add(4)
go errorWatch(c)
// Do not start incoming until resume has completed
c.workers.Add(3)
go alllogic(c)
go outgoing(c)
go incoming(c)
// Take care of any messages in the store
if c.options.CleanSession == false {
c.resume(c.options.ResumeSubs)
} else {
c.persist.Reset()
}
DEBUG.Println(CLI, "exit startClient")
t.flowComplete()
}()
@@ -283,28 +345,33 @@ func (c *client) reconnect() {
)
for rc != 0 && c.status != disconnected {
cm := newConnectMsgFromOptions(&c.options)
for _, broker := range c.options.Servers {
CONN:
cm := newConnectMsgFromOptions(&c.options, broker)
DEBUG.Println(CLI, "about to write new connect msg")
c.conn, err = openConnection(broker, &c.options.TLSConfig, c.options.ConnectTimeout)
c.conn, err = openConnection(broker, c.options.TLSConfig, c.options.ConnectTimeout, c.options.HTTPHeaders)
if err == nil {
DEBUG.Println(CLI, "socket connected to broker")
switch c.options.ProtocolVersion {
case 0x83:
DEBUG.Println(CLI, "Using MQTT 3.1b protocol")
cm.ProtocolName = "MQIsdp"
cm.ProtocolVersion = 0x83
case 0x84:
DEBUG.Println(CLI, "Using MQTT 3.1.1b protocol")
cm.ProtocolName = "MQTT"
cm.ProtocolVersion = 0x84
case 3:
DEBUG.Println(CLI, "Using MQTT 3.1 protocol")
cm.ProtocolName = "MQIsdp"
cm.ProtocolVersion = 3
default:
DEBUG.Println(CLI, "Using MQTT 3.1.1 protocol")
c.options.ProtocolVersion = 4
cm.ProtocolName = "MQTT"
cm.ProtocolVersion = 4
}
cm.Write(c.conn)
rc = c.connect()
rc, _ = c.connect()
if rc != packets.Accepted {
c.conn.Close()
c.conn = nil
@@ -313,11 +380,6 @@ func (c *client) reconnect() {
ERROR.Println(CLI, "Connecting to", broker, "CONNACK was not Accepted, but rather", packets.ConnackReturnCodes[rc])
continue
}
if c.options.ProtocolVersion == 4 {
DEBUG.Println(CLI, "Trying reconnect using MQTT 3.1 protocol")
c.options.ProtocolVersion = 3
goto CONN
}
}
break
} else {
@@ -339,64 +401,69 @@ func (c *client) reconnect() {
}
}
// Disconnect() must have been called while we were trying to reconnect.
if c.status == disconnected {
if c.connectionStatus() == disconnected {
DEBUG.Println(CLI, "Client moved to disconnected state while reconnecting, abandoning reconnect")
return
}
c.stop = make(chan struct{})
if c.options.KeepAlive != 0 {
atomic.StoreInt32(&c.pingOutstanding, 0)
atomic.StoreInt64(&c.lastReceived, time.Now().Unix())
atomic.StoreInt64(&c.lastSent, time.Now().Unix())
c.workers.Add(1)
go keepalive(c)
}
c.stop = make(chan struct{})
c.setConnected(connected)
DEBUG.Println(CLI, "client is reconnected")
if c.options.OnConnect != nil {
go c.options.OnConnect(c)
}
c.workers.Add(4)
go errorWatch(c)
c.workers.Add(3)
go alllogic(c)
go outgoing(c)
go incoming(c)
c.resume(false)
}
// This function is only used for receiving a connack
// when the connection is first started.
// This prevents receiving incoming data while resume
// is in progress if clean session is false.
func (c *client) connect() byte {
func (c *client) connect() (byte, bool) {
DEBUG.Println(NET, "connect started")
ca, err := packets.ReadPacket(c.conn)
if err != nil {
ERROR.Println(NET, "connect got error", err)
return packets.ErrNetworkError
return packets.ErrNetworkError, false
}
if ca == nil {
ERROR.Println(NET, "received nil packet")
return packets.ErrNetworkError
return packets.ErrNetworkError, false
}
msg, ok := ca.(*packets.ConnackPacket)
if !ok {
ERROR.Println(NET, "received msg that was not CONNACK")
return packets.ErrNetworkError
return packets.ErrNetworkError, false
}
DEBUG.Println(NET, "received connack")
return msg.ReturnCode
return msg.ReturnCode, msg.SessionPresent
}
// Disconnect will end the connection with the server, but not before waiting
// the specified number of milliseconds to wait for existing work to be
// completed.
func (c *client) Disconnect(quiesce uint) {
if c.status == connected {
status := atomic.LoadUint32(&c.status)
if status == connected {
DEBUG.Println(CLI, "disconnecting")
c.setConnected(disconnected)
@@ -409,6 +476,7 @@ func (c *client) Disconnect(quiesce uint) {
} else {
WARN.Println(CLI, "Disconnect() called but not connected (disconnected/reconnecting)")
c.setConnected(disconnected)
return
}
c.disconnect()
@@ -456,7 +524,9 @@ func (c *client) closeStop() {
case <-c.stop:
DEBUG.Println("In disconnect and stop channel is already closed")
default:
close(c.stop)
if c.stop != nil {
close(c.stop)
}
}
}
@@ -472,6 +542,7 @@ func (c *client) disconnect() {
c.closeStop()
c.closeConn()
c.workers.Wait()
c.messageIds.cleanUp()
close(c.stopRouter)
DEBUG.Println(CLI, "disconnected")
c.persist.Close()
@@ -507,13 +578,17 @@ func (c *client) Publish(topic string, qos byte, retained bool, payload interfac
return token
}
DEBUG.Println(CLI, "sending publish message, topic:", topic)
if pub.Qos != 0 && pub.MessageID == 0 {
pub.MessageID = c.getID(token)
token.messageID = pub.MessageID
}
persistOutbound(c.persist, pub)
c.obound <- &PacketAndToken{p: pub, t: token}
if c.connectionStatus() == reconnecting {
DEBUG.Println(CLI, "storing publish message (reconnecting), topic:", topic)
} else {
DEBUG.Println(CLI, "sending publish message, topic:", topic)
c.obound <- &PacketAndToken{p: pub, t: token}
}
return token
}
@@ -536,6 +611,10 @@ func (c *client) Subscribe(topic string, qos byte, callback MessageHandler) Toke
sub.Qoss = append(sub.Qoss, qos)
DEBUG.Println(CLI, sub.String())
if strings.HasPrefix(topic, "$share") {
topic = strings.Join(strings.Split(topic, "/")[2:], "/")
}
if callback != nil {
c.msgRouter.addRoute(topic, callback)
}
@@ -575,6 +654,61 @@ func (c *client) SubscribeMultiple(filters map[string]byte, callback MessageHand
return token
}
// Load all stored messages and resend them
// Call this to ensure QOS > 1,2 even after an application crash
func (c *client) resume(subscription bool) {
storedKeys := c.persist.All()
for _, key := range storedKeys {
packet := c.persist.Get(key)
details := packet.Details()
if isKeyOutbound(key) {
switch packet.(type) {
case *packets.SubscribePacket:
if subscription {
DEBUG.Println(STR, fmt.Sprintf("loaded pending subscribe (%d)", details.MessageID))
token := newToken(packets.Subscribe).(*SubscribeToken)
c.oboundP <- &PacketAndToken{p: packet, t: token}
}
case *packets.UnsubscribePacket:
if subscription {
DEBUG.Println(STR, fmt.Sprintf("loaded pending unsubscribe (%d)", details.MessageID))
token := newToken(packets.Unsubscribe).(*UnsubscribeToken)
c.oboundP <- &PacketAndToken{p: packet, t: token}
}
case *packets.PubrelPacket:
DEBUG.Println(STR, fmt.Sprintf("loaded pending pubrel (%d)", details.MessageID))
select {
case c.oboundP <- &PacketAndToken{p: packet, t: nil}:
case <-c.stop:
}
case *packets.PublishPacket:
token := newToken(packets.Publish).(*PublishToken)
token.messageID = details.MessageID
c.claimID(token, details.MessageID)
DEBUG.Println(STR, fmt.Sprintf("loaded pending publish (%d)", details.MessageID))
DEBUG.Println(STR, details)
c.obound <- &PacketAndToken{p: packet, t: token}
default:
ERROR.Println(STR, "invalid message type in store (discarded)")
c.persist.Del(key)
}
} else {
switch packet.(type) {
case *packets.PubrelPacket, *packets.PublishPacket:
DEBUG.Println(STR, fmt.Sprintf("loaded pending incomming (%d)", details.MessageID))
select {
case c.ibound <- packet:
case <-c.stop:
}
default:
ERROR.Println(STR, "invalid message type in store (discarded)")
c.persist.Del(key)
}
}
}
}
// Unsubscribe will end the subscription from each of the topics provided.
// Messages published to those topics from other clients will no longer be
// received.
@@ -599,6 +733,8 @@ func (c *client) Unsubscribe(topics ...string) Token {
return token
}
// OptionsReader returns a ClientOptionsReader which is a copy of the clientoptions
// in use by the client.
func (c *client) OptionsReader() ClientOptionsReader {
r := ClientOptionsReader{options: &c.options}
return r

View File

@@ -18,6 +18,7 @@ import (
"io/ioutil"
"os"
"path"
"sort"
"sync"
"github.com/eclipse/paho.mqtt.golang/packets"
@@ -150,13 +151,18 @@ func (store *FileStore) Reset() {
// lockless
func (store *FileStore) all() []string {
var err error
var keys []string
var files fileInfos
if !store.opened {
ERROR.Println(STR, "Trying to use file store, but not open")
return nil
}
keys := []string{}
files, rderr := ioutil.ReadDir(store.directory)
chkerr(rderr)
files, err = ioutil.ReadDir(store.directory)
chkerr(err)
sort.Sort(files)
for _, f := range files {
DEBUG.Println(STR, "file in All():", f.Name())
name := f.Name()
@@ -233,3 +239,17 @@ func exists(file string) bool {
}
return true
}
type fileInfos []os.FileInfo
func (f fileInfos) Len() int {
return len(f)
}
func (f fileInfos) Swap(i, j int) {
f[i], f[j] = f[j], f[i]
}
func (f fileInfos) Less(i, j int) bool {
return f[i].ModTime().Before(f[j].ModTime())
}

View File

@@ -109,7 +109,7 @@ func (store *MemoryStore) Del(key string) {
if m == nil {
WARN.Println(STR, "memorystore del: message", mid, "not found")
} else {
store.messages[key] = nil
delete(store.messages, key)
DEBUG.Println(STR, "memorystore del: message", mid, "was deleted")
}
}

View File

@@ -15,7 +15,10 @@
package mqtt
import (
"net/url"
"github.com/eclipse/paho.mqtt.golang/packets"
"sync"
)
// Message defines the externals that a message implementation must support
@@ -28,6 +31,7 @@ type Message interface {
Topic() string
MessageID() uint16
Payload() []byte
Ack()
}
type message struct {
@@ -37,6 +41,8 @@ type message struct {
topic string
messageID uint16
payload []byte
once sync.Once
ack func()
}
func (m *message) Duplicate() bool {
@@ -63,7 +69,11 @@ func (m *message) Payload() []byte {
return m.payload
}
func messageFromPublish(p *packets.PublishPacket) Message {
func (m *message) Ack() {
m.once.Do(m.ack)
}
func messageFromPublish(p *packets.PublishPacket, ack func()) Message {
return &message{
duplicate: p.Dup,
qos: p.Qos,
@@ -71,10 +81,11 @@ func messageFromPublish(p *packets.PublishPacket) Message {
topic: p.TopicName,
messageID: p.MessageID,
payload: p.Payload,
ack: ack,
}
}
func newConnectMsgFromOptions(options *ClientOptions) *packets.ConnectPacket {
func newConnectMsgFromOptions(options *ClientOptions, broker *url.URL) *packets.ConnectPacket {
m := packets.NewControlPacket(packets.Connect).(*packets.ConnectPacket)
m.CleanSession = options.CleanSession
@@ -88,17 +99,29 @@ func newConnectMsgFromOptions(options *ClientOptions) *packets.ConnectPacket {
m.WillMessage = options.WillPayload
}
if options.Username != "" {
username := options.Username
password := options.Password
if broker.User != nil {
username = broker.User.Username()
if pwd, ok := broker.User.Password(); ok {
password = pwd
}
}
if options.CredentialsProvider != nil {
username, password = options.CredentialsProvider()
}
if username != "" {
m.UsernameFlag = true
m.Username = options.Username
m.Username = username
//mustn't have password without user as well
if options.Password != "" {
if password != "" {
m.PasswordFlag = true
m.Password = []byte(options.Password)
m.Password = []byte(password)
}
}
m.Keepalive = uint16(options.KeepAlive.Seconds())
m.Keepalive = uint16(options.KeepAlive)
return m
}

View File

@@ -17,6 +17,7 @@ package mqtt
import (
"fmt"
"sync"
"time"
)
// MId is 16 bit message id as specified by the MQTT spec.
@@ -26,7 +27,7 @@ type MId uint16
type messageIds struct {
sync.RWMutex
index map[uint16]Token
index map[uint16]tokenCompletor
}
const (
@@ -44,11 +45,14 @@ func (mids *messageIds) cleanUp() {
t.err = fmt.Errorf("Connection lost before Subscribe completed")
case *UnsubscribeToken:
t.err = fmt.Errorf("Connection lost before Unsubscribe completed")
case nil:
continue
}
token.flowComplete()
}
mids.index = make(map[uint16]Token)
mids.index = make(map[uint16]tokenCompletor)
mids.Unlock()
DEBUG.Println(MID, "cleaned up")
}
func (mids *messageIds) freeID(id uint16) {
@@ -57,7 +61,19 @@ func (mids *messageIds) freeID(id uint16) {
mids.Unlock()
}
func (mids *messageIds) getID(t Token) uint16 {
func (mids *messageIds) claimID(token tokenCompletor, id uint16) {
mids.Lock()
defer mids.Unlock()
if _, ok := mids.index[id]; !ok {
mids.index[id] = token
} else {
old := mids.index[id]
old.flowComplete()
mids.index[id] = token
}
}
func (mids *messageIds) getID(t tokenCompletor) uint16 {
mids.Lock()
defer mids.Unlock()
for i := midMin; i < midMax; i++ {
@@ -69,11 +85,33 @@ func (mids *messageIds) getID(t Token) uint16 {
return 0
}
func (mids *messageIds) getToken(id uint16) Token {
func (mids *messageIds) getToken(id uint16) tokenCompletor {
mids.RLock()
defer mids.RUnlock()
if token, ok := mids.index[id]; ok {
return token
}
return &DummyToken{id: id}
}
type DummyToken struct {
id uint16
}
func (d *DummyToken) Wait() bool {
return true
}
func (d *DummyToken) WaitTimeout(t time.Duration) bool {
return true
}
func (d *DummyToken) flowComplete() {
ERROR.Printf("A lookup for token %d returned nil\n", d.id)
}
func (d *DummyToken) Error() error {
return nil
}
func (d *DummyToken) setError(e error) {}

View File

@@ -19,11 +19,15 @@ import (
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"reflect"
"sync/atomic"
"time"
"github.com/eclipse/paho.mqtt.golang/packets"
"golang.org/x/net/proxy"
"golang.org/x/net/websocket"
)
@@ -34,10 +38,14 @@ func signalError(c chan<- error, err error) {
}
}
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration) (net.Conn, error) {
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration, headers http.Header) (net.Conn, error) {
switch uri.Scheme {
case "ws":
conn, err := websocket.Dial(uri.String(), "mqtt", fmt.Sprintf("http://%s", uri.Host))
config, _ := websocket.NewConfig(uri.String(), fmt.Sprintf("http://%s", uri.Host))
config.Protocol = []string{"mqtt"}
config.Header = headers
config.Dialer = &net.Dialer{Timeout: timeout}
conn, err := websocket.DialConfig(config)
if err != nil {
return nil, err
}
@@ -47,6 +55,8 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration) (net.
config, _ := websocket.NewConfig(uri.String(), fmt.Sprintf("https://%s", uri.Host))
config.Protocol = []string{"mqtt"}
config.TlsConfig = tlsc
config.Header = headers
config.Dialer = &net.Dialer{Timeout: timeout}
conn, err := websocket.DialConfig(config)
if err != nil {
return nil, err
@@ -54,7 +64,23 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration) (net.
conn.PayloadType = websocket.BinaryFrame
return conn, err
case "tcp":
conn, err := net.DialTimeout("tcp", uri.Host, timeout)
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
conn, err := net.DialTimeout("tcp", uri.Host, timeout)
if err != nil {
return nil, err
}
return conn, nil
}
proxyDialer := proxy.FromEnvironment()
conn, err := proxyDialer.Dial("tcp", uri.Host)
if err != nil {
return nil, err
}
return conn, nil
case "unix":
conn, err := net.DialTimeout("unix", uri.Host, timeout)
if err != nil {
return nil, err
}
@@ -64,11 +90,30 @@ func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration) (net.
case "tls":
fallthrough
case "tcps":
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: timeout}, "tcp", uri.Host, tlsc)
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: timeout}, "tcp", uri.Host, tlsc)
if err != nil {
return nil, err
}
return conn, nil
}
proxyDialer := proxy.FromEnvironment()
conn, err := proxyDialer.Dial("tcp", uri.Host)
if err != nil {
return nil, err
}
return conn, nil
tlsConn := tls.Client(conn, tlsc)
err = tlsConn.Handshake()
if err != nil {
conn.Close()
return nil, err
}
return tlsConn, nil
}
return nil, errors.New("Unknown protocol")
}
@@ -92,7 +137,7 @@ func incoming(c *client) {
case c.ibound <- cp:
// Notify keepalive logic that we recently received a packet
if c.options.KeepAlive != 0 {
c.packetResp <- struct{}{}
atomic.StoreInt64(&c.lastReceived, time.Now().Unix())
}
case <-c.stop:
// This avoids a deadlock should a message arrive while shutting down.
@@ -136,6 +181,7 @@ func outgoing(c *client) {
if err := msg.Write(c.conn); err != nil {
ERROR.Println(NET, "outgoing stopped with error", err)
pub.t.setError(err)
signalError(c.errors, err)
return
}
@@ -160,6 +206,9 @@ func outgoing(c *client) {
DEBUG.Println(NET, "obound priority msg to write, type", reflect.TypeOf(msg.p))
if err := msg.p.Write(c.conn); err != nil {
ERROR.Println(NET, "outgoing stopped with error", err)
if msg.t != nil {
msg.t.setError(err)
}
signalError(c.errors, err)
return
}
@@ -172,11 +221,7 @@ func outgoing(c *client) {
}
// Reset ping timer after sending control packet.
if c.options.KeepAlive != 0 {
select {
case c.keepaliveReset <- struct{}{}:
default:
DEBUG.Println(NET, "couldn't send keepalive signal in outbound as channel full")
}
atomic.StoreInt64(&c.lastSent, time.Now().Unix())
}
}
}
@@ -199,20 +244,22 @@ func alllogic(c *client) {
switch m := msg.(type) {
case *packets.PingrespPacket:
DEBUG.Println(NET, "received pingresp")
c.pingResp <- struct{}{}
atomic.StoreInt32(&c.pingOutstanding, 0)
case *packets.SubackPacket:
DEBUG.Println(NET, "received suback, id:", m.MessageID)
token := c.getToken(m.MessageID).(*SubscribeToken)
DEBUG.Println(NET, "granted qoss", m.ReturnCodes)
for i, qos := range m.ReturnCodes {
token.subResult[token.subs[i]] = qos
token := c.getToken(m.MessageID)
switch t := token.(type) {
case *SubscribeToken:
DEBUG.Println(NET, "granted qoss", m.ReturnCodes)
for i, qos := range m.ReturnCodes {
t.subResult[t.subs[i]] = qos
}
}
token.flowComplete()
c.freeID(m.MessageID)
case *packets.UnsubackPacket:
DEBUG.Println(NET, "received unsuback, id:", m.MessageID)
token := c.getToken(m.MessageID).(*UnsubscribeToken)
token.flowComplete()
c.getToken(m.MessageID).flowComplete()
c.freeID(m.MessageID)
case *packets.PublishPacket:
DEBUG.Println(NET, "received publish, msgId:", m.MessageID)
@@ -221,21 +268,14 @@ func alllogic(c *client) {
case 2:
c.incomingPubChan <- m
DEBUG.Println(NET, "done putting msg on incomingPubChan")
pr := packets.NewControlPacket(packets.Pubrec).(*packets.PubrecPacket)
pr.MessageID = m.MessageID
DEBUG.Println(NET, "putting pubrec msg on obound")
c.oboundP <- &PacketAndToken{p: pr, t: nil}
DEBUG.Println(NET, "done putting pubrec msg on obound")
case 1:
c.incomingPubChan <- m
DEBUG.Println(NET, "done putting msg on incomingPubChan")
pa := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
pa.MessageID = m.MessageID
DEBUG.Println(NET, "putting puback msg on obound")
c.oboundP <- &PacketAndToken{p: pa, t: nil}
DEBUG.Println(NET, "done putting puback msg on obound")
case 0:
c.incomingPubChan <- m
select {
case c.incomingPubChan <- m:
case <-c.stop:
}
DEBUG.Println(NET, "done putting msg on incomingPubChan")
}
case *packets.PubackPacket:
@@ -250,15 +290,16 @@ func alllogic(c *client) {
prel.MessageID = m.MessageID
select {
case c.oboundP <- &PacketAndToken{p: prel, t: nil}:
case <-time.After(time.Second):
case <-c.stop:
}
case *packets.PubrelPacket:
DEBUG.Println(NET, "received pubrel, id:", m.MessageID)
pc := packets.NewControlPacket(packets.Pubcomp).(*packets.PubcompPacket)
pc.MessageID = m.MessageID
persistOutbound(c.persist, pc)
select {
case c.oboundP <- &PacketAndToken{p: pc, t: nil}:
case <-time.After(time.Second):
case <-c.stop:
}
case *packets.PubcompPacket:
DEBUG.Println(NET, "received pubcomp, id:", m.MessageID)
@@ -272,14 +313,43 @@ func alllogic(c *client) {
}
}
func (c *client) ackFunc(packet *packets.PublishPacket) func() {
return func() {
switch packet.Qos {
case 2:
pr := packets.NewControlPacket(packets.Pubrec).(*packets.PubrecPacket)
pr.MessageID = packet.MessageID
DEBUG.Println(NET, "putting pubrec msg on obound")
select {
case c.oboundP <- &PacketAndToken{p: pr, t: nil}:
case <-c.stop:
}
DEBUG.Println(NET, "done putting pubrec msg on obound")
case 1:
pa := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
pa.MessageID = packet.MessageID
DEBUG.Println(NET, "putting puback msg on obound")
persistOutbound(c.persist, pa)
select {
case c.oboundP <- &PacketAndToken{p: pa, t: nil}:
case <-c.stop:
}
DEBUG.Println(NET, "done putting puback msg on obound")
case 0:
// do nothing, since there is no need to send an ack packet back
}
}
}
func errorWatch(c *client) {
defer c.workers.Done()
select {
case <-c.stop:
WARN.Println(NET, "errorWatch stopped")
return
case err := <-c.errors:
ERROR.Println(NET, "error triggered, stopping")
c.internalConnLost(err)
go c.internalConnLost(err)
return
}
}

View File

@@ -12,14 +12,22 @@
* Mike Robertson
*/
// Portions copyright © 2018 TIBCO Software Inc.
package mqtt
import (
"crypto/tls"
"net/http"
"net/url"
"strings"
"time"
)
// CredentialsProvider allows the username and password to be updated
// before reconnecting. It should return the current username and password.
type CredentialsProvider func() (username string, password string)
// MessageHandler is a callback type which can be set to be
// executed upon the arrival of messages published to topics
// to which the client is subscribed.
@@ -42,6 +50,7 @@ type ClientOptions struct {
ClientID string
Username string
Password string
CredentialsProvider CredentialsProvider
CleanSession bool
Order bool
WillEnabled bool
@@ -51,18 +60,20 @@ type ClientOptions struct {
WillRetained bool
ProtocolVersion uint
protocolVersionExplicit bool
TLSConfig tls.Config
KeepAlive time.Duration
TLSConfig *tls.Config
KeepAlive int64
PingTimeout time.Duration
ConnectTimeout time.Duration
MaxReconnectInterval time.Duration
AutoReconnect bool
Store Store
DefaultPublishHander MessageHandler
DefaultPublishHandler MessageHandler
OnConnect OnConnectHandler
OnConnectionLost ConnectionLostHandler
WriteTimeout time.Duration
MessageChannelDepth uint
ResumeSubs bool
HTTPHeaders http.Header
}
// NewClientOptions will create a new ClientClientOptions type with some
@@ -89,8 +100,7 @@ func NewClientOptions() *ClientOptions {
WillRetained: false,
ProtocolVersion: 0,
protocolVersionExplicit: false,
TLSConfig: tls.Config{},
KeepAlive: 30 * time.Second,
KeepAlive: 30,
PingTimeout: 10 * time.Second,
ConnectTimeout: 30 * time.Second,
MaxReconnectInterval: 10 * time.Minute,
@@ -100,6 +110,8 @@ func NewClientOptions() *ClientOptions {
OnConnectionLost: DefaultConnectionLostHandler,
WriteTimeout: 0, // 0 represents timeout disabled
MessageChannelDepth: 100,
ResumeSubs: false,
HTTPHeaders: make(map[string][]string),
}
return o
}
@@ -108,11 +120,30 @@ func NewClientOptions() *ClientOptions {
// scheme://host:port
// Where "scheme" is one of "tcp", "ssl", or "ws", "host" is the ip-address (or hostname)
// and "port" is the port on which the broker is accepting connections.
//
// Default values for hostname is "127.0.0.1", for schema is "tcp://".
//
// An example broker URI would look like: tcp://foobar.com:1883
func (o *ClientOptions) AddBroker(server string) *ClientOptions {
brokerURI, err := url.Parse(server)
if err == nil {
o.Servers = append(o.Servers, brokerURI)
if len(server) > 0 && server[0] == ':' {
server = "127.0.0.1" + server
}
if !strings.Contains(server, "://") {
server = "tcp://" + server
}
brokerURI, err := url.Parse(server)
if err != nil {
ERROR.Println(CLI, "Failed to parse %q broker address: %s", server, err)
return o
}
o.Servers = append(o.Servers, brokerURI)
return o
}
// SetResumeSubs will enable resuming of stored (un)subscribe messages when connecting
// but not reconnecting if CleanSession is false. Otherwise these messages are discarded.
func (o *ClientOptions) SetResumeSubs(resume bool) *ClientOptions {
o.ResumeSubs = resume
return o
}
@@ -140,6 +171,15 @@ func (o *ClientOptions) SetPassword(p string) *ClientOptions {
return o
}
// SetCredentialsProvider will set a method to be called by this client when
// connecting to the MQTT broker that provide the current username and password.
// Note: without the use of SSL/TLS, this information will be sent
// in plaintext accross the wire.
func (o *ClientOptions) SetCredentialsProvider(p CredentialsProvider) *ClientOptions {
o.CredentialsProvider = p
return o
}
// SetCleanSession will set the "clean session" flag in the connect message
// when this client connects to an MQTT broker. By setting this flag, you are
// indicating that no messages saved by the broker for this client should be
@@ -164,7 +204,7 @@ func (o *ClientOptions) SetOrderMatters(order bool) *ClientOptions {
// to an MQTT broker. Please read the official Go documentation for more
// information.
func (o *ClientOptions) SetTLSConfig(t *tls.Config) *ClientOptions {
o.TLSConfig = *t
o.TLSConfig = t
return o
}
@@ -182,7 +222,7 @@ func (o *ClientOptions) SetStore(s Store) *ClientOptions {
// allow the client to know that a connection has not been lost with the
// server.
func (o *ClientOptions) SetKeepAlive(k time.Duration) *ClientOptions {
o.KeepAlive = k
o.KeepAlive = int64(k / time.Second)
return o
}
@@ -197,7 +237,7 @@ func (o *ClientOptions) SetPingTimeout(k time.Duration) *ClientOptions {
// SetProtocolVersion sets the MQTT version to be used to connect to the
// broker. Legitimate values are currently 3 - MQTT 3.1 or 4 - MQTT 3.1.1
func (o *ClientOptions) SetProtocolVersion(pv uint) *ClientOptions {
if pv >= 3 && pv <= 4 {
if (pv >= 3 && pv <= 4) || (pv > 0x80) {
o.ProtocolVersion = pv
o.protocolVersionExplicit = true
}
@@ -235,7 +275,7 @@ func (o *ClientOptions) SetBinaryWill(topic string, payload []byte, qos byte, re
// SetDefaultPublishHandler sets the MessageHandler that will be called when a message
// is received that does not match any known subscriptions.
func (o *ClientOptions) SetDefaultPublishHandler(defaultHandler MessageHandler) *ClientOptions {
o.DefaultPublishHander = defaultHandler
o.DefaultPublishHandler = defaultHandler
return o
}
@@ -291,3 +331,10 @@ func (o *ClientOptions) SetMessageChannelDepth(s uint) *ClientOptions {
o.MessageChannelDepth = s
return o
}
// SetHTTPHeaders sets the additional HTTP headers that will be sent in the WebSocket
// opening handshake.
func (o *ClientOptions) SetHTTPHeaders(h http.Header) *ClientOptions {
o.HTTPHeaders = h
return o
}

View File

@@ -16,6 +16,7 @@ package mqtt
import (
"crypto/tls"
"net/http"
"net/url"
"time"
)
@@ -25,6 +26,7 @@ type ClientOptionsReader struct {
options *ClientOptions
}
//Servers returns a slice of the servers defined in the clientoptions
func (r *ClientOptionsReader) Servers() []*url.URL {
s := make([]*url.URL, len(r.options.Servers))
@@ -36,21 +38,31 @@ func (r *ClientOptionsReader) Servers() []*url.URL {
return s
}
//ResumeSubs returns true if resuming stored (un)sub is enabled
func (r *ClientOptionsReader) ResumeSubs() bool {
s := r.options.ResumeSubs
return s
}
//ClientID returns the set client id
func (r *ClientOptionsReader) ClientID() string {
s := r.options.ClientID
return s
}
//Username returns the set username
func (r *ClientOptionsReader) Username() string {
s := r.options.Username
return s
}
//Password returns the set password
func (r *ClientOptionsReader) Password() string {
s := r.options.Password
return s
}
//CleanSession returns whether Cleansession is set
func (r *ClientOptionsReader) CleanSession() bool {
s := r.options.CleanSession
return s
@@ -91,13 +103,13 @@ func (r *ClientOptionsReader) ProtocolVersion() uint {
return s
}
func (r *ClientOptionsReader) TLSConfig() tls.Config {
func (r *ClientOptionsReader) TLSConfig() *tls.Config {
s := r.options.TLSConfig
return s
}
func (r *ClientOptionsReader) KeepAlive() time.Duration {
s := r.options.KeepAlive
s := time.Duration(r.options.KeepAlive * int64(time.Second))
return s
}
@@ -130,3 +142,8 @@ func (r *ClientOptionsReader) MessageChannelDepth() uint {
s := r.options.MessageChannelDepth
return s
}
func (r *ClientOptionsReader) HTTPHeaders() http.Header {
h := r.options.HTTPHeaders
return h
}

View File

@@ -16,87 +16,51 @@ package mqtt
import (
"errors"
"sync/atomic"
"time"
"github.com/eclipse/paho.mqtt.golang/packets"
)
func keepalive(c *client) {
defer c.workers.Done()
DEBUG.Println(PNG, "keepalive starting")
var checkInterval int64
var pingSent time.Time
receiveInterval := c.options.KeepAlive + (1 * time.Second)
pingTimer := timer{Timer: time.NewTimer(c.options.KeepAlive)}
receiveTimer := timer{Timer: time.NewTimer(receiveInterval)}
pingRespTimer := timer{Timer: time.NewTimer(c.options.PingTimeout)}
if c.options.KeepAlive > 10 {
checkInterval = 5
} else {
checkInterval = c.options.KeepAlive / 2
}
pingRespTimer.Stop()
intervalTicker := time.NewTicker(time.Duration(checkInterval * int64(time.Second)))
defer intervalTicker.Stop()
for {
select {
case <-c.stop:
DEBUG.Println(PNG, "keepalive stopped")
c.workers.Done()
return
case <-pingTimer.C:
sendPing(&pingTimer, &pingRespTimer, c)
case <-c.keepaliveReset:
DEBUG.Println(NET, "resetting ping timer")
pingTimer.Reset(c.options.KeepAlive)
case <-c.pingResp:
DEBUG.Println(NET, "resetting ping timeout timer")
pingRespTimer.Stop()
pingTimer.Reset(c.options.KeepAlive)
receiveTimer.Reset(receiveInterval)
case <-c.packetResp:
DEBUG.Println(NET, "resetting receive timer")
receiveTimer.Reset(receiveInterval)
case <-receiveTimer.C:
receiveTimer.SetRead(true)
receiveTimer.Reset(receiveInterval)
sendPing(&pingTimer, &pingRespTimer, c)
case <-pingRespTimer.C:
pingRespTimer.SetRead(true)
CRITICAL.Println(PNG, "pingresp not received, disconnecting")
c.workers.Done()
c.internalConnLost(errors.New("pingresp not received, disconnecting"))
pingTimer.Stop()
return
case <-intervalTicker.C:
DEBUG.Println(PNG, "ping check", time.Now().Unix()-atomic.LoadInt64(&c.lastSent))
if time.Now().Unix()-atomic.LoadInt64(&c.lastSent) >= c.options.KeepAlive || time.Now().Unix()-atomic.LoadInt64(&c.lastReceived) >= c.options.KeepAlive {
if atomic.LoadInt32(&c.pingOutstanding) == 0 {
DEBUG.Println(PNG, "keepalive sending ping")
ping := packets.NewControlPacket(packets.Pingreq).(*packets.PingreqPacket)
//We don't want to wait behind large messages being sent, the Write call
//will block until it it able to send the packet.
atomic.StoreInt32(&c.pingOutstanding, 1)
ping.Write(c.conn)
atomic.StoreInt64(&c.lastSent, time.Now().Unix())
pingSent = time.Now()
}
}
if atomic.LoadInt32(&c.pingOutstanding) > 0 && time.Now().Sub(pingSent) >= c.options.PingTimeout {
CRITICAL.Println(PNG, "pingresp not received, disconnecting")
c.errors <- errors.New("pingresp not received, disconnecting")
return
}
}
}
}
type timer struct {
*time.Timer
readFrom bool
}
func (t *timer) SetRead(v bool) {
t.readFrom = v
}
func (t *timer) Stop() bool {
defer t.SetRead(true)
if !t.Timer.Stop() && !t.readFrom {
<-t.C
return false
}
return true
}
func (t *timer) Reset(d time.Duration) bool {
defer t.SetRead(false)
t.Stop()
return t.Timer.Reset(d)
}
func sendPing(pt *timer, rt *timer, c *client) {
pt.SetRead(true)
DEBUG.Println(PNG, "keepalive sending ping")
ping := packets.NewControlPacket(packets.Pingreq).(*packets.PingreqPacket)
//We don't want to wait behind large messages being sent, the Write call
//will block until it it able to send the packet.
ping.Write(c.conn)
rt.Reset(c.options.PingTimeout)
}

View File

@@ -57,12 +57,23 @@ func match(route []string, topic []string) bool {
if (route[0] == "+") || (route[0] == topic[0]) {
return match(route[1:], topic[1:])
}
return false
}
func routeIncludesTopic(route, topic string) bool {
return match(strings.Split(route, "/"), strings.Split(topic, "/"))
return match(routeSplit(route), strings.Split(topic, "/"))
}
// removes $share and sharename when splitting the route to allow
// shared subscription routes to correctly match the topic
func routeSplit(route string) []string {
var result []string
if strings.HasPrefix(route, "$share") {
result = strings.Split(route, "/")[2:]
} else {
result = strings.Split(route, "/")
}
return result
}
// match takes the topic string of the published message and does a basic compare to the
@@ -135,24 +146,39 @@ func (r *router) matchAndDispatch(messages <-chan *packets.PublishPacket, order
case message := <-messages:
sent := false
r.RLock()
m := messageFromPublish(message, client.ackFunc(message))
handlers := []MessageHandler{}
for e := r.routes.Front(); e != nil; e = e.Next() {
if e.Value.(*route).match(message.TopicName) {
if order {
e.Value.(*route).callback(client, messageFromPublish(message))
handlers = append(handlers, e.Value.(*route).callback)
} else {
go e.Value.(*route).callback(client, messageFromPublish(message))
hd := e.Value.(*route).callback
go func() {
hd(client, m)
m.Ack()
}()
}
sent = true
}
}
if !sent && r.defaultHandler != nil {
if order {
r.defaultHandler(client, messageFromPublish(message))
handlers = append(handlers, r.defaultHandler)
} else {
go r.defaultHandler(client, messageFromPublish(message))
go func() {
r.defaultHandler(client, m)
m.Ack()
}()
}
}
r.RUnlock()
for _, handler := range handlers {
func() {
handler(client, m)
m.Ack()
}()
}
case <-r.stop:
return
}

View File

@@ -50,6 +50,16 @@ func mIDFromKey(key string) uint16 {
return uint16(i)
}
// Return true if key prefix is outbound
func isKeyOutbound(key string) bool {
return key[:2] == outboundPrefix
}
// Return true if key prefix is inbound
func isKeyInbound(key string) bool {
return key[:2] == inboundPrefix
}
// Return a string of the form "i.[id]"
func inboundKeyFromMID(id uint16) string {
return fmt.Sprintf("%s%d", inboundPrefix, id)

View File

@@ -19,24 +19,33 @@ import (
"github.com/eclipse/paho.mqtt.golang/packets"
)
//PacketAndToken is a struct that contains both a ControlPacket and a
//Token. This struct is passed via channels between the client interface
//code and the underlying code responsible for sending and receiving
//MQTT messages.
// PacketAndToken is a struct that contains both a ControlPacket and a
// Token. This struct is passed via channels between the client interface
// code and the underlying code responsible for sending and receiving
// MQTT messages.
type PacketAndToken struct {
p packets.ControlPacket
t Token
t tokenCompletor
}
//Token defines the interface for the tokens used to indicate when
//actions have completed.
// Token defines the interface for the tokens used to indicate when
// actions have completed.
type Token interface {
Wait() bool
WaitTimeout(time.Duration) bool
flowComplete()
Error() error
}
type TokenErrorSetter interface {
setError(error)
}
type tokenCompletor interface {
Token
TokenErrorSetter
flowComplete()
}
type baseToken struct {
m sync.RWMutex
complete chan struct{}
@@ -56,25 +65,34 @@ func (b *baseToken) Wait() bool {
return b.ready
}
// WaitTimeout takes a time in ms to wait for the flow associated with the
// WaitTimeout takes a time.Duration to wait for the flow associated with the
// Token to complete, returns true if it returned before the timeout or
// returns false if the timeout occurred. In the case of a timeout the Token
// does not have an error set in case the caller wishes to wait again
func (b *baseToken) WaitTimeout(d time.Duration) bool {
b.m.Lock()
defer b.m.Unlock()
if !b.ready {
timer := time.NewTimer(d)
select {
case <-b.complete:
b.ready = true
case <-time.After(d):
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
}
}
return b.ready
}
func (b *baseToken) flowComplete() {
close(b.complete)
select {
case <-b.complete:
default:
close(b.complete)
}
}
func (b *baseToken) Error() error {
@@ -83,7 +101,12 @@ func (b *baseToken) Error() error {
return b.err
}
func newToken(tType byte) Token {
func (b *baseToken) setError(e error) {
b.err = e
b.flowComplete()
}
func newToken(tType byte) tokenCompletor {
switch tType {
case packets.Connect:
return &ConnectToken{baseToken: baseToken{complete: make(chan struct{})}}
@@ -99,59 +122,68 @@ func newToken(tType byte) Token {
return nil
}
//ConnectToken is an extension of Token containing the extra fields
//required to provide information about calls to Connect()
// ConnectToken is an extension of Token containing the extra fields
// required to provide information about calls to Connect()
type ConnectToken struct {
baseToken
returnCode byte
returnCode byte
sessionPresent bool
}
//ReturnCode returns the acknowlegement code in the connack sent
//in response to a Connect()
// ReturnCode returns the acknowlegement code in the connack sent
// in response to a Connect()
func (c *ConnectToken) ReturnCode() byte {
c.m.RLock()
defer c.m.RUnlock()
return c.returnCode
}
//PublishToken is an extension of Token containing the extra fields
//required to provide information about calls to Publish()
// SessionPresent returns a bool representing the value of the
// session present field in the connack sent in response to a Connect()
func (c *ConnectToken) SessionPresent() bool {
c.m.RLock()
defer c.m.RUnlock()
return c.sessionPresent
}
// PublishToken is an extension of Token containing the extra fields
// required to provide information about calls to Publish()
type PublishToken struct {
baseToken
messageID uint16
}
//MessageID returns the MQTT message ID that was assigned to the
//Publish packet when it was sent to the broker
// MessageID returns the MQTT message ID that was assigned to the
// Publish packet when it was sent to the broker
func (p *PublishToken) MessageID() uint16 {
return p.messageID
}
//SubscribeToken is an extension of Token containing the extra fields
//required to provide information about calls to Subscribe()
// SubscribeToken is an extension of Token containing the extra fields
// required to provide information about calls to Subscribe()
type SubscribeToken struct {
baseToken
subs []string
subResult map[string]byte
}
//Result returns a map of topics that were subscribed to along with
//the matching return code from the broker. This is either the Qos
//value of the subscription or an error code.
// Result returns a map of topics that were subscribed to along with
// the matching return code from the broker. This is either the Qos
// value of the subscription or an error code.
func (s *SubscribeToken) Result() map[string]byte {
s.m.RLock()
defer s.m.RUnlock()
return s.subResult
}
//UnsubscribeToken is an extension of Token containing the extra fields
//required to provide information about calls to Unsubscribe()
// UnsubscribeToken is an extension of Token containing the extra fields
// required to provide information about calls to Unsubscribe()
type UnsubscribeToken struct {
baseToken
}
//DisconnectToken is an extension of Token containing the extra fields
//required to provide information about calls to Disconnect()
// DisconnectToken is an extension of Token containing the extra fields
// required to provide information about calls to Disconnect()
type DisconnectToken struct {
baseToken
}

View File

@@ -14,23 +14,27 @@
package mqtt
import (
"io/ioutil"
"log"
type (
// Logger interface allows implementations to provide to this package any
// object that implements the methods defined in it.
Logger interface {
Println(v ...interface{})
Printf(format string, v ...interface{})
}
// NOOPLogger implements the logger that does not perform any operation
// by default. This allows us to efficiently discard the unwanted messages.
NOOPLogger struct{}
)
func (NOOPLogger) Println(v ...interface{}) {}
func (NOOPLogger) Printf(format string, v ...interface{}) {}
// Internal levels of library output that are initialised to not print
// anything but can be overridden by programmer
var (
ERROR *log.Logger
CRITICAL *log.Logger
WARN *log.Logger
DEBUG *log.Logger
ERROR Logger = NOOPLogger{}
CRITICAL Logger = NOOPLogger{}
WARN Logger = NOOPLogger{}
DEBUG Logger = NOOPLogger{}
)
func init() {
ERROR = log.New(ioutil.Discard, "", 0)
CRITICAL = log.New(ioutil.Discard, "", 0)
WARN = log.New(ioutil.Discard, "", 0)
DEBUG = log.New(ioutil.Discard, "", 0)
}

168
vendor/golang.org/x/net/internal/socks/client.go generated vendored Normal file
View File

@@ -0,0 +1,168 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package socks
import (
"context"
"errors"
"io"
"net"
"strconv"
"time"
)
var (
noDeadline = time.Time{}
aLongTimeAgo = time.Unix(1, 0)
)
func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
host, port, err := splitHostPort(address)
if err != nil {
return nil, err
}
if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
c.SetDeadline(deadline)
defer c.SetDeadline(noDeadline)
}
if ctx != context.Background() {
errCh := make(chan error, 1)
done := make(chan struct{})
defer func() {
close(done)
if ctxErr == nil {
ctxErr = <-errCh
}
}()
go func() {
select {
case <-ctx.Done():
c.SetDeadline(aLongTimeAgo)
errCh <- ctx.Err()
case <-done:
errCh <- nil
}
}()
}
b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
b = append(b, Version5)
if len(d.AuthMethods) == 0 || d.Authenticate == nil {
b = append(b, 1, byte(AuthMethodNotRequired))
} else {
ams := d.AuthMethods
if len(ams) > 255 {
return nil, errors.New("too many authentication methods")
}
b = append(b, byte(len(ams)))
for _, am := range ams {
b = append(b, byte(am))
}
}
if _, ctxErr = c.Write(b); ctxErr != nil {
return
}
if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
return
}
if b[0] != Version5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
am := AuthMethod(b[1])
if am == AuthMethodNoAcceptableMethods {
return nil, errors.New("no acceptable authentication methods")
}
if d.Authenticate != nil {
if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
return
}
}
b = b[:0]
b = append(b, Version5, byte(d.cmd), 0)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
b = append(b, AddrTypeIPv4)
b = append(b, ip4...)
} else if ip6 := ip.To16(); ip6 != nil {
b = append(b, AddrTypeIPv6)
b = append(b, ip6...)
} else {
return nil, errors.New("unknown address type")
}
} else {
if len(host) > 255 {
return nil, errors.New("FQDN too long")
}
b = append(b, AddrTypeFQDN)
b = append(b, byte(len(host)))
b = append(b, host...)
}
b = append(b, byte(port>>8), byte(port))
if _, ctxErr = c.Write(b); ctxErr != nil {
return
}
if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
return
}
if b[0] != Version5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
return nil, errors.New("unknown error " + cmdErr.String())
}
if b[2] != 0 {
return nil, errors.New("non-zero reserved field")
}
l := 2
var a Addr
switch b[3] {
case AddrTypeIPv4:
l += net.IPv4len
a.IP = make(net.IP, net.IPv4len)
case AddrTypeIPv6:
l += net.IPv6len
a.IP = make(net.IP, net.IPv6len)
case AddrTypeFQDN:
if _, err := io.ReadFull(c, b[:1]); err != nil {
return nil, err
}
l += int(b[0])
default:
return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
}
if cap(b) < l {
b = make([]byte, l)
} else {
b = b[:l]
}
if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
return
}
if a.IP != nil {
copy(a.IP, b)
} else {
a.Name = string(b[:len(b)-2])
}
a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
return &a, nil
}
func splitHostPort(address string) (string, int, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return "", 0, err
}
portnum, err := strconv.Atoi(port)
if err != nil {
return "", 0, err
}
if 1 > portnum || portnum > 0xffff {
return "", 0, errors.New("port number out of range " + port)
}
return host, portnum, nil
}

317
vendor/golang.org/x/net/internal/socks/socks.go generated vendored Normal file
View File

@@ -0,0 +1,317 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package socks provides a SOCKS version 5 client implementation.
//
// SOCKS protocol version 5 is defined in RFC 1928.
// Username/Password authentication for SOCKS version 5 is defined in
// RFC 1929.
package socks
import (
"context"
"errors"
"io"
"net"
"strconv"
)
// A Command represents a SOCKS command.
type Command int
func (cmd Command) String() string {
switch cmd {
case CmdConnect:
return "socks connect"
case cmdBind:
return "socks bind"
default:
return "socks " + strconv.Itoa(int(cmd))
}
}
// An AuthMethod represents a SOCKS authentication method.
type AuthMethod int
// A Reply represents a SOCKS command reply code.
type Reply int
func (code Reply) String() string {
switch code {
case StatusSucceeded:
return "succeeded"
case 0x01:
return "general SOCKS server failure"
case 0x02:
return "connection not allowed by ruleset"
case 0x03:
return "network unreachable"
case 0x04:
return "host unreachable"
case 0x05:
return "connection refused"
case 0x06:
return "TTL expired"
case 0x07:
return "command not supported"
case 0x08:
return "address type not supported"
default:
return "unknown code: " + strconv.Itoa(int(code))
}
}
// Wire protocol constants.
const (
Version5 = 0x05
AddrTypeIPv4 = 0x01
AddrTypeFQDN = 0x03
AddrTypeIPv6 = 0x04
CmdConnect Command = 0x01 // establishes an active-open forward proxy connection
cmdBind Command = 0x02 // establishes a passive-open forward proxy connection
AuthMethodNotRequired AuthMethod = 0x00 // no authentication required
AuthMethodUsernamePassword AuthMethod = 0x02 // use username/password
AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authentication methods
StatusSucceeded Reply = 0x00
)
// An Addr represents a SOCKS-specific address.
// Either Name or IP is used exclusively.
type Addr struct {
Name string // fully-qualified domain name
IP net.IP
Port int
}
func (a *Addr) Network() string { return "socks" }
func (a *Addr) String() string {
if a == nil {
return "<nil>"
}
port := strconv.Itoa(a.Port)
if a.IP == nil {
return net.JoinHostPort(a.Name, port)
}
return net.JoinHostPort(a.IP.String(), port)
}
// A Conn represents a forward proxy connection.
type Conn struct {
net.Conn
boundAddr net.Addr
}
// BoundAddr returns the address assigned by the proxy server for
// connecting to the command target address from the proxy server.
func (c *Conn) BoundAddr() net.Addr {
if c == nil {
return nil
}
return c.boundAddr
}
// A Dialer holds SOCKS-specific options.
type Dialer struct {
cmd Command // either CmdConnect or cmdBind
proxyNetwork string // network between a proxy server and a client
proxyAddress string // proxy server address
// ProxyDial specifies the optional dial function for
// establishing the transport connection.
ProxyDial func(context.Context, string, string) (net.Conn, error)
// AuthMethods specifies the list of request authention
// methods.
// If empty, SOCKS client requests only AuthMethodNotRequired.
AuthMethods []AuthMethod
// Authenticate specifies the optional authentication
// function. It must be non-nil when AuthMethods is not empty.
// It must return an error when the authentication is failed.
Authenticate func(context.Context, io.ReadWriter, AuthMethod) error
}
// DialContext connects to the provided address on the provided
// network.
//
// The returned error value may be a net.OpError. When the Op field of
// net.OpError contains "socks", the Source field contains a proxy
// server address and the Addr field contains a command target
// address.
//
// See func Dial of the net package of standard library for a
// description of the network and address parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if ctx == nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
}
var err error
var c net.Conn
if d.ProxyDial != nil {
c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress)
} else {
var dd net.Dialer
c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress)
}
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
a, err := d.connect(ctx, c, address)
if err != nil {
c.Close()
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
return &Conn{Conn: c, boundAddr: a}, nil
}
// DialWithConn initiates a connection from SOCKS server to the target
// network and address using the connection c that is already
// connected to the SOCKS server.
//
// It returns the connection's local address assigned by the SOCKS
// server.
func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if ctx == nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
}
a, err := d.connect(ctx, c, address)
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
return a, nil
}
// Dial connects to the provided address on the provided network.
//
// Unlike DialContext, it returns a raw transport connection instead
// of a forward proxy connection.
//
// Deprecated: Use DialContext or DialWithConn instead.
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
var err error
var c net.Conn
if d.ProxyDial != nil {
c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
} else {
c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
}
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
c.Close()
return nil, err
}
return c, nil
}
func (d *Dialer) validateTarget(network, address string) error {
switch network {
case "tcp", "tcp6", "tcp4":
default:
return errors.New("network not implemented")
}
switch d.cmd {
case CmdConnect, cmdBind:
default:
return errors.New("command not implemented")
}
return nil
}
func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
for i, s := range []string{d.proxyAddress, address} {
host, port, err := splitHostPort(s)
if err != nil {
return nil, nil, err
}
a := &Addr{Port: port}
a.IP = net.ParseIP(host)
if a.IP == nil {
a.Name = host
}
if i == 0 {
proxy = a
} else {
dst = a
}
}
return
}
// NewDialer returns a new Dialer that dials through the provided
// proxy server's network and address.
func NewDialer(network, address string) *Dialer {
return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect}
}
const (
authUsernamePasswordVersion = 0x01
authStatusSucceeded = 0x00
)
// UsernamePassword are the credentials for the username/password
// authentication method.
type UsernamePassword struct {
Username string
Password string
}
// Authenticate authenticates a pair of username and password with the
// proxy server.
func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error {
switch auth {
case AuthMethodNotRequired:
return nil
case AuthMethodUsernamePassword:
if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) == 0 || len(up.Password) > 255 {
return errors.New("invalid username/password")
}
b := []byte{authUsernamePasswordVersion}
b = append(b, byte(len(up.Username)))
b = append(b, up.Username...)
b = append(b, byte(len(up.Password)))
b = append(b, up.Password...)
// TODO(mikio): handle IO deadlines and cancelation if
// necessary
if _, err := rw.Write(b); err != nil {
return err
}
if _, err := io.ReadFull(rw, b[:2]); err != nil {
return err
}
if b[0] != authUsernamePasswordVersion {
return errors.New("invalid username/password version")
}
if b[1] != authStatusSucceeded {
return errors.New("username/password authentication failed")
}
return nil
}
return errors.New("unsupported authentication method " + strconv.Itoa(int(auth)))
}

18
vendor/golang.org/x/net/proxy/direct.go generated vendored Normal file
View File

@@ -0,0 +1,18 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"net"
)
type direct struct{}
// Direct is a direct proxy: one that makes network connections directly.
var Direct = direct{}
func (direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}

140
vendor/golang.org/x/net/proxy/per_host.go generated vendored Normal file
View File

@@ -0,0 +1,140 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"net"
"strings"
)
// A PerHost directs connections to a default Dialer unless the host name
// requested matches one of a number of exceptions.
type PerHost struct {
def, bypass Dialer
bypassNetworks []*net.IPNet
bypassIPs []net.IP
bypassZones []string
bypassHosts []string
}
// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
return &PerHost{
def: defaultDialer,
bypass: bypass,
}
}
// Dial connects to the address addr on the given network through either
// defaultDialer or bypass.
func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return p.dialerForRequest(host).Dial(network, addr)
}
func (p *PerHost) dialerForRequest(host string) Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
}
}
for _, bypassIP := range p.bypassIPs {
if bypassIP.Equal(ip) {
return p.bypass
}
}
return p.def
}
for _, zone := range p.bypassZones {
if strings.HasSuffix(host, zone) {
return p.bypass
}
if host == zone[1:] {
// For a zone ".example.com", we match "example.com"
// too.
return p.bypass
}
}
for _, bypassHost := range p.bypassHosts {
if bypassHost == host {
return p.bypass
}
}
return p.def
}
// AddFromString parses a string that contains comma-separated values
// specifying hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a host name
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *PerHost) AddFromString(s string) {
hosts := strings.Split(s, ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
if len(host) == 0 {
continue
}
if strings.Contains(host, "/") {
// We assume that it's a CIDR address like 127.0.0.0/8
if _, net, err := net.ParseCIDR(host); err == nil {
p.AddNetwork(net)
}
continue
}
if ip := net.ParseIP(host); ip != nil {
p.AddIP(ip)
continue
}
if strings.HasPrefix(host, "*.") {
p.AddZone(host[1:])
continue
}
p.AddHost(host)
}
}
// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *PerHost) AddIP(ip net.IP) {
p.bypassIPs = append(p.bypassIPs, ip)
}
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match.
func (p *PerHost) AddNetwork(net *net.IPNet) {
p.bypassNetworks = append(p.bypassNetworks, net)
}
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
p.bypassZones = append(p.bypassZones, zone)
}
// AddHost specifies a host name that will use the bypass proxy.
func (p *PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host)
}

139
vendor/golang.org/x/net/proxy/proxy.go generated vendored Normal file
View File

@@ -0,0 +1,139 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package proxy provides support for a variety of protocols to proxy network
// data.
package proxy // import "golang.org/x/net/proxy"
import (
"errors"
"net"
"net/url"
"os"
"sync"
)
// A Dialer is a means to establish a connection.
type Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
}
// Auth contains authentication parameters that specific Dialers may require.
type Auth struct {
User, Password string
}
// FromEnvironment returns the dialer specified by the proxy related variables in
// the environment.
func FromEnvironment() Dialer {
allProxy := allProxyEnv.Get()
if len(allProxy) == 0 {
return Direct
}
proxyURL, err := url.Parse(allProxy)
if err != nil {
return Direct
}
proxy, err := FromURL(proxyURL, Direct)
if err != nil {
return Direct
}
noProxy := noProxyEnv.Get()
if len(noProxy) == 0 {
return proxy
}
perHost := NewPerHost(proxy, Direct)
perHost.AddFromString(noProxy)
return perHost
}
// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error)
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) {
if proxySchemes == nil {
proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error))
}
proxySchemes[scheme] = f
}
// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
func FromURL(u *url.URL, forward Dialer) (Dialer, error) {
var auth *Auth
if u.User != nil {
auth = new(Auth)
auth.User = u.User.Username()
if p, ok := u.User.Password(); ok {
auth.Password = p
}
}
switch u.Scheme {
case "socks5", "socks5h":
addr := u.Hostname()
port := u.Port()
if port == "" {
port = "1080"
}
return SOCKS5("tcp", net.JoinHostPort(addr, port), auth, forward)
}
// If the scheme doesn't match any of the built-in schemes, see if it
// was registered by another package.
if proxySchemes != nil {
if f, ok := proxySchemes[u.Scheme]; ok {
return f(u, forward)
}
}
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
var (
allProxyEnv = &envOnce{
names: []string{"ALL_PROXY", "all_proxy"},
}
noProxyEnv = &envOnce{
names: []string{"NO_PROXY", "no_proxy"},
}
)
// envOnce looks up an environment variable (optionally by multiple
// names) once. It mitigates expensive lookups on some platforms
// (e.g. Windows).
// (Borrowed from net/http/transport.go)
type envOnce struct {
names []string
once sync.Once
val string
}
func (e *envOnce) Get() string {
e.once.Do(e.init)
return e.val
}
func (e *envOnce) init() {
for _, n := range e.names {
e.val = os.Getenv(n)
if e.val != "" {
return
}
}
}
// reset is used by tests
func (e *envOnce) reset() {
e.once = sync.Once{}
e.val = ""
}

36
vendor/golang.org/x/net/proxy/socks5.go generated vendored Normal file
View File

@@ -0,0 +1,36 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"context"
"net"
"golang.org/x/net/internal/socks"
)
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given
// address with an optional username and password.
// See RFC 1928 and RFC 1929.
func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) {
d := socks.NewDialer(network, address)
if forward != nil {
d.ProxyDial = func(_ context.Context, network string, address string) (net.Conn, error) {
return forward.Dial(network, address)
}
}
if auth != nil {
up := socks.UsernamePassword{
Username: auth.User,
Password: auth.Password,
}
d.AuthMethods = []socks.AuthMethod{
socks.AuthMethodNotRequired,
socks.AuthMethodUsernamePassword,
}
d.Authenticate = up.Authenticate
}
return d, nil
}

18
vendor/vendor.json vendored
View File

@@ -257,10 +257,10 @@
"revisionTime": "2016-08-05T00:47:13Z"
},
{
"checksumSHA1": "LHwm1G0lFyQ3X0iHZ/MOYnAtjd4=",
"checksumSHA1": "VPQEUnKynzSvgugyvL3rxjduidg=",
"path": "github.com/eclipse/paho.mqtt.golang",
"revision": "d06cc70ac43d625e602946b5ff2346ddebb768e4",
"revisionTime": "2017-06-02T16:30:32Z"
"revision": "379fd9f99ba5b1f02c9fffb5e5952416ef9301dc",
"revisionTime": "2018-11-29T14:54:54Z"
},
{
"checksumSHA1": "g2AaO9VMaxzFQZkrCfxBUV3yAyM=",
@@ -1200,6 +1200,12 @@
"revision": "da118f7b8e5954f39d0d2130ab35d4bf0e3cb344",
"revisionTime": "2017-04-23T14:02:46Z"
},
{
"checksumSHA1": "f3Y7JIZH61oMmp8nphqe8Mg+XoU=",
"path": "golang.org/x/net/internal/socks",
"revision": "1e06a53dbb7e2ed46e91183f219db23c6943c532",
"revisionTime": "2018-12-20T03:20:21Z"
},
{
"checksumSHA1": "UxahDzW2v4mf/+aFxruuupaoIwo=",
"path": "golang.org/x/net/internal/timeseries",
@@ -1212,6 +1218,12 @@
"revision": "da118f7b8e5954f39d0d2130ab35d4bf0e3cb344",
"revisionTime": "2017-04-23T14:02:46Z"
},
{
"checksumSHA1": "mCMW3hvbWFW1k5il9yyO7ELOdws=",
"path": "golang.org/x/net/proxy",
"revision": "1e06a53dbb7e2ed46e91183f219db23c6943c532",
"revisionTime": "2018-12-20T03:20:21Z"
},
{
"checksumSHA1": "9EZG3s2eOREO7WkBvigjk57wK/8=",
"path": "golang.org/x/net/trace",