mirror of
https://github.com/minio/minio.git
synced 2024-12-25 06:35:56 -05:00
Simplify HTTP trace related code (#7833)
This commit is contained in:
parent
c1d2b3d5c3
commit
183ec094c4
@ -43,6 +43,7 @@ import (
|
|||||||
"github.com/minio/minio/pkg/mem"
|
"github.com/minio/minio/pkg/mem"
|
||||||
xnet "github.com/minio/minio/pkg/net"
|
xnet "github.com/minio/minio/pkg/net"
|
||||||
"github.com/minio/minio/pkg/quick"
|
"github.com/minio/minio/pkg/quick"
|
||||||
|
trace "github.com/minio/minio/pkg/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -1483,11 +1484,6 @@ func (a adminAPIHandlers) TraceHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if globalTrace == nil {
|
|
||||||
writeErrorResponseJSON(ctx, w, errorCodes.ToAPIErr(ErrServerNotInitialized), r.URL)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Avoid reusing tcp connection if read timeout is hit
|
// Avoid reusing tcp connection if read timeout is hit
|
||||||
// This is needed to make r.Context().Done() work as
|
// This is needed to make r.Context().Done() work as
|
||||||
// expected in case of read timeout
|
// expected in case of read timeout
|
||||||
@ -1496,14 +1492,33 @@ func (a adminAPIHandlers) TraceHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
doneCh := make(chan struct{})
|
doneCh := make(chan struct{})
|
||||||
defer close(doneCh)
|
defer close(doneCh)
|
||||||
|
|
||||||
traceCh := globalTrace.Trace(doneCh, trcAll)
|
// Trace Publisher and peer-trace-client uses nonblocking send and hence does not wait for slow receivers.
|
||||||
|
// Use buffered channel to take care of burst sends or slow w.Write()
|
||||||
|
traceCh := make(chan interface{}, 4000)
|
||||||
|
|
||||||
|
filter := func(entry interface{}) bool {
|
||||||
|
if trcAll {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
trcInfo := entry.(trace.Info)
|
||||||
|
return !strings.HasPrefix(trcInfo.ReqInfo.Path, minioReservedBucketPath)
|
||||||
|
}
|
||||||
|
remoteHosts := getRemoteHosts(globalEndpoints)
|
||||||
|
peers, err := getRestClients(remoteHosts)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
globalHTTPTrace.Subscribe(traceCh, doneCh, filter)
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
peer.Trace(traceCh, doneCh, trcAll)
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := json.NewEncoder(w)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case entry := <-traceCh:
|
case entry := <-traceCh:
|
||||||
if _, err := w.Write(entry); err != nil {
|
if err := enc.Encode(entry); err != nil {
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err := w.Write([]byte("\n")); err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.(http.Flusher).Flush()
|
w.(http.Flusher).Flush()
|
||||||
|
@ -158,9 +158,6 @@ func StartGateway(ctx *cli.Context, gw Gateway) {
|
|||||||
registerSTSRouter(router)
|
registerSTSRouter(router)
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize globalTrace system
|
|
||||||
globalTrace = NewTraceSys(context.Background(), globalEndpoints)
|
|
||||||
|
|
||||||
enableConfigOps := globalEtcdClient != nil && gatewayName == "nas"
|
enableConfigOps := globalEtcdClient != nil && gatewayName == "nas"
|
||||||
enableIAMOps := globalEtcdClient != nil
|
enableIAMOps := globalEtcdClient != nil
|
||||||
|
|
||||||
|
@ -35,6 +35,7 @@ import (
|
|||||||
"github.com/minio/minio/pkg/dns"
|
"github.com/minio/minio/pkg/dns"
|
||||||
iampolicy "github.com/minio/minio/pkg/iam/policy"
|
iampolicy "github.com/minio/minio/pkg/iam/policy"
|
||||||
"github.com/minio/minio/pkg/iam/validator"
|
"github.com/minio/minio/pkg/iam/validator"
|
||||||
|
"github.com/minio/minio/pkg/pubsub"
|
||||||
)
|
)
|
||||||
|
|
||||||
// minio configuration related constants.
|
// minio configuration related constants.
|
||||||
@ -161,7 +162,7 @@ var (
|
|||||||
|
|
||||||
// global Trace system to send HTTP request/response logs to
|
// global Trace system to send HTTP request/response logs to
|
||||||
// registered listeners
|
// registered listeners
|
||||||
globalTrace *HTTPTraceSys
|
globalHTTPTrace = pubsub.New()
|
||||||
|
|
||||||
globalEndpoints EndpointList
|
globalEndpoints EndpointList
|
||||||
|
|
||||||
|
@ -326,24 +326,24 @@ func extractPostPolicyFormValues(ctx context.Context, form *multipart.Form) (fil
|
|||||||
// Log headers and body.
|
// Log headers and body.
|
||||||
func httpTraceAll(f http.HandlerFunc) http.HandlerFunc {
|
func httpTraceAll(f http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if !globalTrace.HasTraceListeners() {
|
if !globalHTTPTrace.HasSubscribers() {
|
||||||
f.ServeHTTP(w, r)
|
f.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
trace := Trace(f, true, w, r)
|
trace := Trace(f, true, w, r)
|
||||||
globalTrace.Publish(trace)
|
globalHTTPTrace.Publish(trace)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log only the headers.
|
// Log only the headers.
|
||||||
func httpTraceHdrs(f http.HandlerFunc) http.HandlerFunc {
|
func httpTraceHdrs(f http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if !globalTrace.HasTraceListeners() {
|
if !globalHTTPTrace.HasSubscribers() {
|
||||||
f.ServeHTTP(w, r)
|
f.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
trace := Trace(f, false, w, r)
|
trace := Trace(f, false, w, r)
|
||||||
globalTrace.Publish(trace)
|
globalHTTPTrace.Publish(trace)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
115
cmd/httptrace.go
115
cmd/httptrace.go
@ -1,115 +0,0 @@
|
|||||||
/*
|
|
||||||
* MinIO Cloud Storage, (C) 2019 MinIO, Inc.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/minio/minio/cmd/logger"
|
|
||||||
"github.com/minio/minio/pkg/pubsub"
|
|
||||||
"github.com/minio/minio/pkg/trace"
|
|
||||||
)
|
|
||||||
|
|
||||||
//HTTPTraceSys holds global trace state
|
|
||||||
type HTTPTraceSys struct {
|
|
||||||
peers []*peerRESTClient
|
|
||||||
pubsub *pubsub.PubSub
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTraceSys - creates new HTTPTraceSys with all nodes subscribed to
|
|
||||||
// the trace pub sub system
|
|
||||||
func NewTraceSys(ctx context.Context, endpoints EndpointList) *HTTPTraceSys {
|
|
||||||
remoteHosts := getRemoteHosts(endpoints)
|
|
||||||
remoteClients, err := getRestClients(remoteHosts)
|
|
||||||
if err != nil {
|
|
||||||
logger.FatalIf(err, "Unable to start httptrace sub system")
|
|
||||||
}
|
|
||||||
|
|
||||||
ps := pubsub.New()
|
|
||||||
return &HTTPTraceSys{
|
|
||||||
remoteClients, ps,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasTraceListeners returns true if trace listeners are registered
|
|
||||||
// for this node or peers
|
|
||||||
func (sys *HTTPTraceSys) HasTraceListeners() bool {
|
|
||||||
return sys != nil && sys.pubsub.HasSubscribers()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Publish - publishes trace message to the http trace pubsub system
|
|
||||||
func (sys *HTTPTraceSys) Publish(traceMsg trace.Info) {
|
|
||||||
sys.pubsub.Publish(traceMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trace writes http trace to writer
|
|
||||||
func (sys *HTTPTraceSys) Trace(doneCh chan struct{}, trcAll bool) chan []byte {
|
|
||||||
traceCh := make(chan []byte)
|
|
||||||
go func() {
|
|
||||||
defer close(traceCh)
|
|
||||||
|
|
||||||
var wg = &sync.WaitGroup{}
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
|
||||||
ch := sys.pubsub.Subscribe()
|
|
||||||
defer sys.pubsub.Unsubscribe(ch)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case entry := <-ch:
|
|
||||||
trcInfo := entry.(trace.Info)
|
|
||||||
path := strings.TrimPrefix(trcInfo.ReqInfo.Path, "/")
|
|
||||||
// omit inter-node traffic if trcAll is false
|
|
||||||
if !trcAll && strings.HasPrefix(path, minioReservedBucket) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
buf.Reset()
|
|
||||||
enc := json.NewEncoder(buf)
|
|
||||||
enc.SetEscapeHTML(false)
|
|
||||||
if err := enc.Encode(trcInfo); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
traceCh <- buf.Bytes()
|
|
||||||
case <-doneCh:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for _, peer := range sys.peers {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(peer *peerRESTClient) {
|
|
||||||
defer wg.Done()
|
|
||||||
ch, err := peer.Trace(doneCh, trcAll)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for entry := range ch {
|
|
||||||
traceCh <- entry
|
|
||||||
}
|
|
||||||
}(peer)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}()
|
|
||||||
return traceCh
|
|
||||||
}
|
|
@ -17,7 +17,6 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
@ -34,6 +33,7 @@ import (
|
|||||||
"github.com/minio/minio/pkg/madmin"
|
"github.com/minio/minio/pkg/madmin"
|
||||||
xnet "github.com/minio/minio/pkg/net"
|
xnet "github.com/minio/minio/pkg/net"
|
||||||
"github.com/minio/minio/pkg/policy"
|
"github.com/minio/minio/pkg/policy"
|
||||||
|
trace "github.com/minio/minio/pkg/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// client to talk to peer Nodes.
|
// client to talk to peer Nodes.
|
||||||
@ -435,52 +435,59 @@ func (client *peerRESTClient) BackgroundHealStatus() (madmin.BgHealState, error)
|
|||||||
return state, err
|
return state, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trace - send http trace request to peer nodes
|
func (client *peerRESTClient) doTrace(traceCh chan interface{}, doneCh chan struct{}, trcAll bool) {
|
||||||
func (client *peerRESTClient) Trace(doneCh chan struct{}, trcAll bool) (chan []byte, error) {
|
|
||||||
ch := make(chan []byte)
|
|
||||||
go func() {
|
|
||||||
cleanupFn := func(cancel context.CancelFunc, ch chan []byte, respBody io.ReadCloser) {
|
|
||||||
close(ch)
|
|
||||||
if cancel != nil {
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
http.DrainBody(respBody)
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
values := make(url.Values)
|
values := make(url.Values)
|
||||||
values.Set(peerRESTTraceAll, strconv.FormatBool(trcAll))
|
values.Set(peerRESTTraceAll, strconv.FormatBool(trcAll))
|
||||||
// get cancellation context to properly unsubscribe peers
|
|
||||||
|
// To cancel the REST request in case doneCh gets closed.
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
respBody, err := client.callWithContext(ctx, peerRESTMethodTrace, values, nil, -1)
|
|
||||||
if err != nil {
|
cancelCh := make(chan struct{})
|
||||||
//retry
|
defer close(cancelCh)
|
||||||
time.Sleep(5 * time.Second)
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-doneCh:
|
case <-doneCh:
|
||||||
cleanupFn(cancel, ch, respBody)
|
case <-cancelCh:
|
||||||
return
|
// There was an error in the REST request.
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
continue
|
|
||||||
}
|
|
||||||
bio := bufio.NewScanner(respBody)
|
|
||||||
go func() {
|
|
||||||
<-doneCh
|
|
||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
// Unmarshal each line, returns marshaled values.
|
|
||||||
for bio.Scan() {
|
respBody, err := client.callWithContext(ctx, peerRESTMethodTrace, values, nil, -1)
|
||||||
ch <- bio.Bytes()
|
defer http.DrainBody(respBody)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dec := gob.NewDecoder(respBody)
|
||||||
|
for {
|
||||||
|
var info trace.Info
|
||||||
|
if err = dec.Decode(&info); err != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
|
case traceCh <- info:
|
||||||
|
default:
|
||||||
|
// Do not block on slow receivers.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trace - send http trace request to peer nodes
|
||||||
|
func (client *peerRESTClient) Trace(traceCh chan interface{}, doneCh chan struct{}, trcAll bool) {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
client.doTrace(traceCh, doneCh, trcAll)
|
||||||
|
select {
|
||||||
case <-doneCh:
|
case <-doneCh:
|
||||||
cleanupFn(cancel, ch, respBody)
|
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
|
// There was error in the REST request, retry after sometime as probably the peer is down.
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return ch, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRemoteHosts(endpoints EndpointList) []*xnet.Host {
|
func getRemoteHosts(endpoints EndpointList) []*xnet.Host {
|
||||||
|
@ -19,7 +19,6 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -679,32 +678,33 @@ func (s *peerRESTServer) TraceHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.Header().Set("Connection", "close")
|
w.Header().Set("Connection", "close")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.(http.Flusher).Flush()
|
w.(http.Flusher).Flush()
|
||||||
ch := globalTrace.pubsub.Subscribe()
|
|
||||||
defer globalTrace.pubsub.Unsubscribe(ch)
|
|
||||||
|
|
||||||
enc := json.NewEncoder(w)
|
filter := func(entry interface{}) bool {
|
||||||
enc.SetEscapeHTML(false)
|
if trcAll {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
trcInfo := entry.(trace.Info)
|
||||||
|
return !strings.HasPrefix(trcInfo.ReqInfo.Path, minioReservedBucketPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
defer close(doneCh)
|
||||||
|
|
||||||
|
// Trace Publisher uses nonblocking publish and hence does not wait for slow subscribers.
|
||||||
|
// Use buffered channel to take care of burst sends or slow w.Write()
|
||||||
|
ch := make(chan interface{}, 2000)
|
||||||
|
globalHTTPTrace.Subscribe(ch, doneCh, filter)
|
||||||
|
|
||||||
|
enc := gob.NewEncoder(w)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case entry := <-ch:
|
case entry := <-ch:
|
||||||
trcInfo := entry.(trace.Info)
|
if err := enc.Encode(entry); err != nil {
|
||||||
path := strings.TrimPrefix(trcInfo.ReqInfo.Path, "/")
|
|
||||||
// omit inter-node traffic if trcAll is false
|
|
||||||
if !trcAll && strings.HasPrefix(path, minioReservedBucket) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := enc.Encode(trcInfo); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := w.Write([]byte("\n")); err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.(http.Flusher).Flush()
|
w.(http.Flusher).Flush()
|
||||||
case <-r.Context().Done():
|
case <-r.Context().Done():
|
||||||
return
|
return
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -294,9 +294,6 @@ func serverMain(ctx *cli.Context) {
|
|||||||
globalSweepHealState = initHealState()
|
globalSweepHealState = initHealState()
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize globalTrace system
|
|
||||||
globalTrace = NewTraceSys(context.Background(), globalEndpoints)
|
|
||||||
|
|
||||||
// Configure server.
|
// Configure server.
|
||||||
var handler http.Handler
|
var handler http.Handler
|
||||||
handler, err = configureServerHandler(globalEndpoints)
|
handler, err = configureServerHandler(globalEndpoints)
|
||||||
|
@ -17,9 +17,7 @@
|
|||||||
package madmin
|
package madmin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -35,7 +33,7 @@ type TraceInfo struct {
|
|||||||
|
|
||||||
// Trace - listen on http trace notifications.
|
// Trace - listen on http trace notifications.
|
||||||
func (adm AdminClient) Trace(allTrace bool, doneCh <-chan struct{}) <-chan TraceInfo {
|
func (adm AdminClient) Trace(allTrace bool, doneCh <-chan struct{}) <-chan TraceInfo {
|
||||||
traceInfoCh := make(chan TraceInfo, 1)
|
traceInfoCh := make(chan TraceInfo)
|
||||||
// Only success, start a routine to start reading line by line.
|
// Only success, start a routine to start reading line by line.
|
||||||
go func(traceInfoCh chan<- TraceInfo) {
|
go func(traceInfoCh chan<- TraceInfo) {
|
||||||
defer close(traceInfoCh)
|
defer close(traceInfoCh)
|
||||||
@ -58,30 +56,16 @@ func (adm AdminClient) Trace(allTrace bool, doneCh <-chan struct{}) <-chan Trace
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize a new bufio scanner, to read line by line.
|
dec := json.NewDecoder(resp.Body)
|
||||||
bio := bufio.NewScanner(resp.Body)
|
for {
|
||||||
|
var info trace.Info
|
||||||
// Close the response body.
|
if err = dec.Decode(&info); err != nil {
|
||||||
defer resp.Body.Close()
|
break
|
||||||
|
|
||||||
// Unmarshal each line, returns marshaled values.
|
|
||||||
for bio.Scan() {
|
|
||||||
var traceRec trace.Info
|
|
||||||
if err = json.Unmarshal(bio.Bytes(), &traceRec); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-doneCh:
|
case <-doneCh:
|
||||||
return
|
return
|
||||||
case traceInfoCh <- TraceInfo{Trace: traceRec}:
|
case traceInfoCh <- TraceInfo{Trace: info}:
|
||||||
}
|
|
||||||
}
|
|
||||||
// Look for any underlying errors.
|
|
||||||
if err = bio.Err(); err != nil {
|
|
||||||
// For an unexpected connection drop from server, we close the body
|
|
||||||
// and re-connect.
|
|
||||||
if err == io.ErrUnexpectedEOF {
|
|
||||||
resp.Body.Close()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -20,64 +20,65 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Sub - subscriber entity.
|
||||||
|
type Sub struct {
|
||||||
|
ch chan interface{}
|
||||||
|
filter func(entry interface{}) bool
|
||||||
|
}
|
||||||
|
|
||||||
// PubSub holds publishers and subscribers
|
// PubSub holds publishers and subscribers
|
||||||
type PubSub struct {
|
type PubSub struct {
|
||||||
subs []chan interface{}
|
subs []*Sub
|
||||||
pub chan interface{}
|
sync.RWMutex
|
||||||
mutex sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// process item to subscribers.
|
// Publish message to the subscribers.
|
||||||
func (ps *PubSub) process() {
|
// Note that publish is always nob-blocking send so that we don't block on slow receivers.
|
||||||
for item := range ps.pub {
|
// Hence receivers should use buffered channel so as not to miss the published events.
|
||||||
ps.mutex.Lock()
|
|
||||||
for _, sub := range ps.subs {
|
|
||||||
go func(s chan interface{}) {
|
|
||||||
s <- item
|
|
||||||
}(sub)
|
|
||||||
}
|
|
||||||
ps.mutex.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Publish message to pubsub system
|
|
||||||
func (ps *PubSub) Publish(item interface{}) {
|
func (ps *PubSub) Publish(item interface{}) {
|
||||||
ps.pub <- item
|
ps.RLock()
|
||||||
|
defer ps.RUnlock()
|
||||||
|
|
||||||
|
for _, sub := range ps.subs {
|
||||||
|
if sub.filter(item) {
|
||||||
|
select {
|
||||||
|
case sub.ch <- item:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subscribe - Adds a subscriber to pubsub system
|
// Subscribe - Adds a subscriber to pubsub system
|
||||||
func (ps *PubSub) Subscribe() chan interface{} {
|
func (ps *PubSub) Subscribe(subCh chan interface{}, doneCh chan struct{}, filter func(entry interface{}) bool) {
|
||||||
ps.mutex.Lock()
|
ps.Lock()
|
||||||
defer ps.mutex.Unlock()
|
defer ps.Unlock()
|
||||||
ch := make(chan interface{})
|
|
||||||
ps.subs = append(ps.subs, ch)
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unsubscribe removes current subscriber
|
sub := &Sub{subCh, filter}
|
||||||
func (ps *PubSub) Unsubscribe(ch chan interface{}) {
|
ps.subs = append(ps.subs, sub)
|
||||||
ps.mutex.Lock()
|
|
||||||
defer ps.mutex.Unlock()
|
|
||||||
|
|
||||||
for i, sub := range ps.subs {
|
go func() {
|
||||||
if sub == ch {
|
<-doneCh
|
||||||
close(ch)
|
|
||||||
|
ps.Lock()
|
||||||
|
defer ps.Unlock()
|
||||||
|
|
||||||
|
for i, s := range ps.subs {
|
||||||
|
if s == sub {
|
||||||
ps.subs = append(ps.subs[:i], ps.subs[i+1:]...)
|
ps.subs = append(ps.subs[:i], ps.subs[i+1:]...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasSubscribers returns true if pubsub system has subscribers
|
// HasSubscribers returns true if pubsub system has subscribers
|
||||||
func (ps *PubSub) HasSubscribers() bool {
|
func (ps *PubSub) HasSubscribers() bool {
|
||||||
ps.mutex.Lock()
|
ps.RLock()
|
||||||
defer ps.mutex.Unlock()
|
defer ps.RUnlock()
|
||||||
return len(ps.subs) > 0
|
return len(ps.subs) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// New inits a PubSub system
|
// New inits a PubSub system
|
||||||
func New() *PubSub {
|
func New() *PubSub {
|
||||||
ps := &PubSub{}
|
return &PubSub{}
|
||||||
ps.pub = make(chan interface{})
|
|
||||||
go ps.process()
|
|
||||||
return ps
|
|
||||||
}
|
}
|
||||||
|
@ -19,12 +19,19 @@ package pubsub
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSubscribe(t *testing.T) {
|
func TestSubscribe(t *testing.T) {
|
||||||
ps := New()
|
ps := New()
|
||||||
ps.Subscribe()
|
ch1 := make(chan interface{}, 1)
|
||||||
ps.Subscribe()
|
ch2 := make(chan interface{}, 1)
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
defer close(doneCh)
|
||||||
|
ps.Subscribe(ch1, doneCh, nil)
|
||||||
|
ps.Subscribe(ch2, doneCh, nil)
|
||||||
|
ps.Lock()
|
||||||
|
defer ps.Unlock()
|
||||||
if len(ps.subs) != 2 {
|
if len(ps.subs) != 2 {
|
||||||
t.Errorf("expected 2 subscribers")
|
t.Errorf("expected 2 subscribers")
|
||||||
}
|
}
|
||||||
@ -32,20 +39,33 @@ func TestSubscribe(t *testing.T) {
|
|||||||
|
|
||||||
func TestUnsubscribe(t *testing.T) {
|
func TestUnsubscribe(t *testing.T) {
|
||||||
ps := New()
|
ps := New()
|
||||||
c1 := ps.Subscribe()
|
ch1 := make(chan interface{}, 1)
|
||||||
ps.Subscribe()
|
ch2 := make(chan interface{}, 1)
|
||||||
ps.Unsubscribe(c1)
|
doneCh1 := make(chan struct{})
|
||||||
|
doneCh2 := make(chan struct{})
|
||||||
|
ps.Subscribe(ch1, doneCh1, nil)
|
||||||
|
ps.Subscribe(ch2, doneCh2, nil)
|
||||||
|
|
||||||
|
close(doneCh1)
|
||||||
|
// Allow for the above statement to take effect.
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
ps.Lock()
|
||||||
if len(ps.subs) != 1 {
|
if len(ps.subs) != 1 {
|
||||||
t.Errorf("expected 1 subscriber")
|
t.Errorf("expected 1 subscriber")
|
||||||
}
|
}
|
||||||
|
ps.Unlock()
|
||||||
|
close(doneCh2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPubSub(t *testing.T) {
|
func TestPubSub(t *testing.T) {
|
||||||
ps := New()
|
ps := New()
|
||||||
c1 := ps.Subscribe()
|
ch1 := make(chan interface{}, 1)
|
||||||
|
doneCh1 := make(chan struct{})
|
||||||
|
defer close(doneCh1)
|
||||||
|
ps.Subscribe(ch1, doneCh1, func(entry interface{}) bool { return true })
|
||||||
val := "hello"
|
val := "hello"
|
||||||
ps.Publish(val)
|
ps.Publish(val)
|
||||||
msg := <-c1
|
msg := <-ch1
|
||||||
if msg != "hello" {
|
if msg != "hello" {
|
||||||
t.Errorf(fmt.Sprintf("expected %s , found %s", val, msg))
|
t.Errorf(fmt.Sprintf("expected %s , found %s", val, msg))
|
||||||
}
|
}
|
||||||
@ -53,13 +73,17 @@ func TestPubSub(t *testing.T) {
|
|||||||
|
|
||||||
func TestMultiPubSub(t *testing.T) {
|
func TestMultiPubSub(t *testing.T) {
|
||||||
ps := New()
|
ps := New()
|
||||||
c1 := ps.Subscribe()
|
ch1 := make(chan interface{}, 1)
|
||||||
c2 := ps.Subscribe()
|
ch2 := make(chan interface{}, 1)
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
defer close(doneCh)
|
||||||
|
ps.Subscribe(ch1, doneCh, func(entry interface{}) bool { return true })
|
||||||
|
ps.Subscribe(ch2, doneCh, func(entry interface{}) bool { return true })
|
||||||
val := "hello"
|
val := "hello"
|
||||||
ps.Publish(val)
|
ps.Publish(val)
|
||||||
|
|
||||||
msg1 := <-c1
|
msg1 := <-ch1
|
||||||
msg2 := <-c2
|
msg2 := <-ch2
|
||||||
if msg1 != "hello" && msg2 != "hello" {
|
if msg1 != "hello" && msg2 != "hello" {
|
||||||
t.Errorf(fmt.Sprintf("expected both subscribers to have%s , found %s and %s", val, msg1, msg2))
|
t.Errorf(fmt.Sprintf("expected both subscribers to have%s , found %s and %s", val, msg1, msg2))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user