mark pubsub type safe via generics (#15961)

This commit is contained in:
Klaus Post 2022-10-28 19:55:42 +02:00 committed by GitHub
parent 6d22e74d11
commit 71954faa3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 91 additions and 105 deletions

View File

@ -50,7 +50,6 @@ import (
"github.com/minio/minio/internal/kms" "github.com/minio/minio/internal/kms"
"github.com/minio/minio/internal/logger" "github.com/minio/minio/internal/logger"
"github.com/minio/minio/internal/logger/message/log" "github.com/minio/minio/internal/logger/message/log"
"github.com/minio/minio/internal/pubsub"
iampolicy "github.com/minio/pkg/iam/policy" iampolicy "github.com/minio/pkg/iam/policy"
xnet "github.com/minio/pkg/net" xnet "github.com/minio/pkg/net"
"github.com/secure-io/sio-go" "github.com/secure-io/sio-go"
@ -1504,16 +1503,12 @@ func (a adminAPIHandlers) TraceHandler(w http.ResponseWriter, r *http.Request) {
// Trace Publisher and peer-trace-client uses nonblocking send and hence does not wait for slow receivers. // Trace Publisher and peer-trace-client uses nonblocking send and hence does not wait for slow receivers.
// Use buffered channel to take care of burst sends or slow w.Write() // Use buffered channel to take care of burst sends or slow w.Write()
traceCh := make(chan pubsub.Maskable, 4000) traceCh := make(chan madmin.TraceInfo, 4000)
peers, _ := newPeerRestClients(globalEndpoints) peers, _ := newPeerRestClients(globalEndpoints)
mask := pubsub.MaskFromMaskable(traceOpts.TraceTypes())
err = globalTrace.Subscribe(mask, traceCh, ctx.Done(), func(entry pubsub.Maskable) bool { err = globalTrace.Subscribe(traceOpts.TraceTypes(), traceCh, ctx.Done(), func(entry madmin.TraceInfo) bool {
if e, ok := entry.(madmin.TraceInfo); ok { return shouldTrace(entry, traceOpts)
return shouldTrace(e, traceOpts)
}
return false
}) })
if err != nil { if err != nil {
writeErrorResponseJSON(ctx, w, errorCodes.ToAPIErr(ErrSlowDown), r.URL) writeErrorResponseJSON(ctx, w, errorCodes.ToAPIErr(ErrSlowDown), r.URL)
@ -1585,7 +1580,7 @@ func (a adminAPIHandlers) ConsoleLogHandler(w http.ResponseWriter, r *http.Reque
setEventStreamHeaders(w) setEventStreamHeaders(w)
logCh := make(chan pubsub.Maskable, 4000) logCh := make(chan log.Info, 4000)
peers, _ := newPeerRestClients(globalEndpoints) peers, _ := newPeerRestClients(globalEndpoints)
@ -1611,9 +1606,8 @@ func (a adminAPIHandlers) ConsoleLogHandler(w http.ResponseWriter, r *http.Reque
for { for {
select { select {
case entry := <-logCh: case log := <-logCh:
log, ok := entry.(log.Info) if log.SendLog(node, logKind) {
if ok && log.SendLog(node, logKind) {
if err := enc.Encode(log); err != nil { if err := enc.Encode(log); err != nil {
return return
} }

View File

@ -22,7 +22,6 @@ import (
"runtime" "runtime"
"github.com/minio/madmin-go" "github.com/minio/madmin-go"
"github.com/minio/minio/internal/pubsub"
) )
// healTask represents what to heal along with options // healTask represents what to heal along with options
@ -54,7 +53,7 @@ type healRoutine struct {
func activeListeners() int { func activeListeners() int {
// Bucket notification and http trace are not costly, it is okay to ignore them // Bucket notification and http trace are not costly, it is okay to ignore them
// while counting the number of concurrent connections // while counting the number of concurrent connections
return int(globalHTTPListen.NumSubscribers(pubsub.MaskAll)) + int(globalTrace.NumSubscribers(pubsub.MaskAll)) return int(globalHTTPListen.Subscribers()) + int(globalTrace.Subscribers())
} }
func waitForLowHTTPReq() { func waitForLowHTTPReq() {

View File

@ -37,7 +37,7 @@ const defaultLogBufferCount = 10000
// HTTPConsoleLoggerSys holds global console logger state // HTTPConsoleLoggerSys holds global console logger state
type HTTPConsoleLoggerSys struct { type HTTPConsoleLoggerSys struct {
sync.RWMutex sync.RWMutex
pubsub *pubsub.PubSub pubsub *pubsub.PubSub[log.Info, madmin.LogMask]
console *console.Target console *console.Target
nodeName string nodeName string
logBuf *ring.Ring logBuf *ring.Ring
@ -46,9 +46,8 @@ type HTTPConsoleLoggerSys struct {
// NewConsoleLogger - creates new HTTPConsoleLoggerSys with all nodes subscribed to // NewConsoleLogger - creates new HTTPConsoleLoggerSys with all nodes subscribed to
// the console logging pub sub system // the console logging pub sub system
func NewConsoleLogger(ctx context.Context) *HTTPConsoleLoggerSys { func NewConsoleLogger(ctx context.Context) *HTTPConsoleLoggerSys {
ps := pubsub.New(8)
return &HTTPConsoleLoggerSys{ return &HTTPConsoleLoggerSys{
pubsub: ps, pubsub: pubsub.New[log.Info, madmin.LogMask](8),
console: console.New(), console: console.New(),
logBuf: ring.New(defaultLogBufferCount), logBuf: ring.New(defaultLogBufferCount),
} }
@ -72,11 +71,11 @@ func (sys *HTTPConsoleLoggerSys) SetNodeName(nodeName string) {
// HasLogListeners returns true if console log listeners are registered // HasLogListeners returns true if console log listeners are registered
// for this node or peers // for this node or peers
func (sys *HTTPConsoleLoggerSys) HasLogListeners() bool { func (sys *HTTPConsoleLoggerSys) HasLogListeners() bool {
return sys != nil && sys.pubsub.NumSubscribers(madmin.LogMaskAll) > 0 return sys != nil && sys.pubsub.Subscribers() > 0
} }
// Subscribe starts console logging for this node. // Subscribe starts console logging for this node.
func (sys *HTTPConsoleLoggerSys) Subscribe(subCh chan pubsub.Maskable, doneCh <-chan struct{}, node string, last int, logKind madmin.LogMask, filter func(entry pubsub.Maskable) bool) error { func (sys *HTTPConsoleLoggerSys) Subscribe(subCh chan log.Info, doneCh <-chan struct{}, node string, last int, logKind madmin.LogMask, filter func(entry log.Info) bool) error {
// Enable console logging for remote client. // Enable console logging for remote client.
if !sys.HasLogListeners() { if !sys.HasLogListeners() {
logger.AddSystemTarget(sys) logger.AddSystemTarget(sys)
@ -116,7 +115,7 @@ func (sys *HTTPConsoleLoggerSys) Subscribe(subCh chan pubsub.Maskable, doneCh <-
} }
} }
} }
return sys.pubsub.Subscribe(pubsub.MaskFromMaskable(madmin.LogMaskAll), subCh, doneCh, filter) return sys.pubsub.Subscribe(madmin.LogMaskAll, subCh, doneCh, filter)
} }
// Init if HTTPConsoleLoggerSys is valid, always returns nil right now // Init if HTTPConsoleLoggerSys is valid, always returns nil right now

View File

@ -29,6 +29,7 @@ import (
"github.com/minio/minio/internal/event" "github.com/minio/minio/internal/event"
xhttp "github.com/minio/minio/internal/http" xhttp "github.com/minio/minio/internal/http"
"github.com/minio/minio/internal/logger" "github.com/minio/minio/internal/logger"
"github.com/minio/minio/internal/pubsub"
"github.com/minio/pkg/bucket/policy" "github.com/minio/pkg/bucket/policy"
) )
@ -321,7 +322,7 @@ func sendEvent(args eventArgs) {
crypto.RemoveSensitiveEntries(args.Object.UserDefined) crypto.RemoveSensitiveEntries(args.Object.UserDefined)
crypto.RemoveInternalEntries(args.Object.UserDefined) crypto.RemoveInternalEntries(args.Object.UserDefined)
if globalHTTPListen.NumSubscribers(args.EventName) > 0 { if globalHTTPListen.NumSubscribers(pubsub.MaskFromMaskable(args.EventName)) > 0 {
globalHTTPListen.Publish(args.ToEvent(false)) globalHTTPListen.Publish(args.ToEvent(false))
} }

View File

@ -27,7 +27,8 @@ import (
"time" "time"
"github.com/minio/console/restapi" "github.com/minio/console/restapi"
minio "github.com/minio/minio-go/v7" "github.com/minio/madmin-go"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/set" "github.com/minio/minio-go/v7/pkg/set"
"github.com/minio/minio/internal/bucket/bandwidth" "github.com/minio/minio/internal/bucket/bandwidth"
"github.com/minio/minio/internal/config" "github.com/minio/minio/internal/config"
@ -220,11 +221,10 @@ var (
// global Trace system to send HTTP request/response // global Trace system to send HTTP request/response
// and Storage/OS calls info to registered listeners. // and Storage/OS calls info to registered listeners.
globalTrace = pubsub.New(8) globalTrace = pubsub.New[madmin.TraceInfo, madmin.TraceType](8)
// global Listen system to send S3 API events to registered listeners // global Listen system to send S3 API events to registered listeners
// Objects are expected to be event.Event globalHTTPListen = pubsub.New[event.Event, pubsub.Mask](0)
globalHTTPListen = pubsub.New(0)
// global console system to send console logs to // global console system to send console logs to
// registered listeners // registered listeners

View File

@ -125,15 +125,11 @@ func (api objectAPIHandlers) ListenNotificationHandler(w http.ResponseWriter, r
// Listen Publisher and peer-listen-client uses nonblocking send and hence does not wait for slow receivers. // Listen Publisher and peer-listen-client uses nonblocking send and hence does not wait for slow receivers.
// Use buffered channel to take care of burst sends or slow w.Write() // Use buffered channel to take care of burst sends or slow w.Write()
listenCh := make(chan pubsub.Maskable, 4000) listenCh := make(chan event.Event, 4000)
peers, _ := newPeerRestClients(globalEndpoints) peers, _ := newPeerRestClients(globalEndpoints)
err := globalHTTPListen.Subscribe(mask, listenCh, ctx.Done(), func(evI pubsub.Maskable) bool { err := globalHTTPListen.Subscribe(mask, listenCh, ctx.Done(), func(ev event.Event) bool {
ev, ok := evI.(event.Event)
if !ok {
return false
}
if ev.S3.Bucket.Name != "" && bucketName != "" { if ev.S3.Bucket.Name != "" && bucketName != "" {
if ev.S3.Bucket.Name != bucketName { if ev.S3.Bucket.Name != bucketName {
return false return false
@ -161,16 +157,9 @@ func (api objectAPIHandlers) ListenNotificationHandler(w http.ResponseWriter, r
enc := json.NewEncoder(w) enc := json.NewEncoder(w)
for { for {
select { select {
case evI := <-listenCh: case ev := <-listenCh:
ev, ok := evI.(event.Event) if err := enc.Encode(struct{ Records []event.Event }{[]event.Event{ev}}); err != nil {
if ok { return
if err := enc.Encode(struct{ Records []event.Event }{[]event.Event{ev}}); err != nil {
return
}
} else {
if _, err := w.Write([]byte(" ")); err != nil {
return
}
} }
if len(listenCh) == 0 { if len(listenCh) == 0 {
// Flush if nothing is queued // Flush if nothing is queued

View File

@ -34,7 +34,7 @@ import (
"github.com/minio/minio/internal/http" "github.com/minio/minio/internal/http"
xhttp "github.com/minio/minio/internal/http" xhttp "github.com/minio/minio/internal/http"
"github.com/minio/minio/internal/logger" "github.com/minio/minio/internal/logger"
"github.com/minio/minio/internal/pubsub" "github.com/minio/minio/internal/logger/message/log"
"github.com/minio/minio/internal/rest" "github.com/minio/minio/internal/rest"
xnet "github.com/minio/pkg/net" xnet "github.com/minio/pkg/net"
"github.com/tinylib/msgp/msgp" "github.com/tinylib/msgp/msgp"
@ -583,7 +583,7 @@ func (client *peerRESTClient) LoadTransitionTierConfig(ctx context.Context) erro
return nil return nil
} }
func (client *peerRESTClient) doTrace(traceCh chan<- pubsub.Maskable, doneCh <-chan struct{}, traceOpts madmin.ServiceTraceOpts) { func (client *peerRESTClient) doTrace(traceCh chan<- madmin.TraceInfo, doneCh <-chan struct{}, traceOpts madmin.ServiceTraceOpts) {
values := make(url.Values) values := make(url.Values)
traceOpts.AddParams(values) traceOpts.AddParams(values)
@ -624,7 +624,7 @@ func (client *peerRESTClient) doTrace(traceCh chan<- pubsub.Maskable, doneCh <-c
} }
} }
func (client *peerRESTClient) doListen(listenCh chan<- pubsub.Maskable, doneCh <-chan struct{}, v url.Values) { func (client *peerRESTClient) doListen(listenCh chan<- event.Event, doneCh <-chan struct{}, v url.Values) {
// To cancel the REST request in case doneCh gets closed. // To cancel the REST request in case doneCh gets closed.
ctx, cancel := context.WithCancel(GlobalContext) ctx, cancel := context.WithCancel(GlobalContext)
@ -663,7 +663,7 @@ func (client *peerRESTClient) doListen(listenCh chan<- pubsub.Maskable, doneCh <
} }
// Listen - listen on peers. // Listen - listen on peers.
func (client *peerRESTClient) Listen(listenCh chan<- pubsub.Maskable, doneCh <-chan struct{}, v url.Values) { func (client *peerRESTClient) Listen(listenCh chan<- event.Event, doneCh <-chan struct{}, v url.Values) {
go func() { go func() {
for { for {
client.doListen(listenCh, doneCh, v) client.doListen(listenCh, doneCh, v)
@ -679,7 +679,7 @@ func (client *peerRESTClient) Listen(listenCh chan<- pubsub.Maskable, doneCh <-c
} }
// Trace - send http trace request to peer nodes // Trace - send http trace request to peer nodes
func (client *peerRESTClient) Trace(traceCh chan<- pubsub.Maskable, doneCh <-chan struct{}, traceOpts madmin.ServiceTraceOpts) { func (client *peerRESTClient) Trace(traceCh chan<- madmin.TraceInfo, doneCh <-chan struct{}, traceOpts madmin.ServiceTraceOpts) {
go func() { go func() {
for { for {
client.doTrace(traceCh, doneCh, traceOpts) client.doTrace(traceCh, doneCh, traceOpts)
@ -694,7 +694,7 @@ func (client *peerRESTClient) Trace(traceCh chan<- pubsub.Maskable, doneCh <-cha
}() }()
} }
func (client *peerRESTClient) doConsoleLog(logCh chan pubsub.Maskable, doneCh <-chan struct{}) { func (client *peerRESTClient) doConsoleLog(logCh chan log.Info, doneCh <-chan struct{}) {
// To cancel the REST request in case doneCh gets closed. // To cancel the REST request in case doneCh gets closed.
ctx, cancel := context.WithCancel(GlobalContext) ctx, cancel := context.WithCancel(GlobalContext)
@ -717,22 +717,20 @@ func (client *peerRESTClient) doConsoleLog(logCh chan pubsub.Maskable, doneCh <-
dec := gob.NewDecoder(respBody) dec := gob.NewDecoder(respBody)
for { for {
var lg madmin.LogInfo var lg log.Info
if err = dec.Decode(&lg); err != nil { if err = dec.Decode(&lg); err != nil {
break break
} }
if lg.DeploymentID != "" { select {
select { case logCh <- lg:
case logCh <- lg: default:
default: // Do not block on slow receivers.
// Do not block on slow receivers.
}
} }
} }
} }
// ConsoleLog - sends request to peer nodes to get console logs // ConsoleLog - sends request to peer nodes to get console logs
func (client *peerRESTClient) ConsoleLog(logCh chan pubsub.Maskable, doneCh <-chan struct{}) { func (client *peerRESTClient) ConsoleLog(logCh chan log.Info, doneCh <-chan struct{}) {
go func() { go func() {
for { for {
client.doConsoleLog(logCh, doneCh) client.doConsoleLog(logCh, doneCh)

View File

@ -35,6 +35,7 @@ import (
b "github.com/minio/minio/internal/bucket/bandwidth" b "github.com/minio/minio/internal/bucket/bandwidth"
"github.com/minio/minio/internal/event" "github.com/minio/minio/internal/event"
"github.com/minio/minio/internal/logger" "github.com/minio/minio/internal/logger"
"github.com/minio/minio/internal/logger/message/log"
"github.com/minio/minio/internal/pubsub" "github.com/minio/minio/internal/pubsub"
"github.com/tinylib/msgp/msgp" "github.com/tinylib/msgp/msgp"
) )
@ -919,13 +920,9 @@ func (s *peerRESTServer) ListenHandler(w http.ResponseWriter, r *http.Request) {
// Listen Publisher uses nonblocking publish and hence does not wait for slow subscribers. // Listen Publisher uses nonblocking publish and hence does not wait for slow subscribers.
// Use buffered channel to take care of burst sends or slow w.Write() // Use buffered channel to take care of burst sends or slow w.Write()
ch := make(chan pubsub.Maskable, 2000) ch := make(chan event.Event, 2000)
err := globalHTTPListen.Subscribe(mask, ch, doneCh, func(evI pubsub.Maskable) bool { err := globalHTTPListen.Subscribe(mask, ch, doneCh, func(ev event.Event) bool {
ev, ok := evI.(event.Event)
if !ok {
return false
}
if ev.S3.Bucket.Name != "" && values.Get(peerRESTListenBucket) != "" { if ev.S3.Bucket.Name != "" && values.Get(peerRESTListenBucket) != "" {
if ev.S3.Bucket.Name != values.Get(peerRESTListenBucket) { if ev.S3.Bucket.Name != values.Get(peerRESTListenBucket) {
return false return false
@ -978,13 +975,9 @@ func (s *peerRESTServer) TraceHandler(w http.ResponseWriter, r *http.Request) {
// Trace Publisher uses nonblocking publish and hence does not wait for slow subscribers. // Trace Publisher uses nonblocking publish and hence does not wait for slow subscribers.
// Use buffered channel to take care of burst sends or slow w.Write() // Use buffered channel to take care of burst sends or slow w.Write()
ch := make(chan pubsub.Maskable, 2000) ch := make(chan madmin.TraceInfo, 2000)
mask := pubsub.MaskFromMaskable(traceOpts.TraceTypes()) err = globalTrace.Subscribe(traceOpts.TraceTypes(), ch, r.Context().Done(), func(entry madmin.TraceInfo) bool {
err = globalTrace.Subscribe(mask, ch, r.Context().Done(), func(entry pubsub.Maskable) bool { return shouldTrace(entry, traceOpts)
if e, ok := entry.(madmin.TraceInfo); ok {
return shouldTrace(e, traceOpts)
}
return false
}) })
if err != nil { if err != nil {
s.writeErrorResponse(w, err) s.writeErrorResponse(w, err)
@ -1125,7 +1118,7 @@ func (s *peerRESTServer) ConsoleLogHandler(w http.ResponseWriter, r *http.Reques
doneCh := make(chan struct{}) doneCh := make(chan struct{})
defer close(doneCh) defer close(doneCh)
ch := make(chan pubsub.Maskable, 2000) ch := make(chan log.Info, 2000)
err := globalConsoleSys.Subscribe(ch, doneCh, "", 0, madmin.LogMaskAll, nil) err := globalConsoleSys.Subscribe(ch, doneCh, "", 0, madmin.LogMaskAll, nil)
if err != nil { if err != nil {
s.writeErrorResponse(w, err) s.writeErrorResponse(w, err)
@ -1141,12 +1134,16 @@ func (s *peerRESTServer) ConsoleLogHandler(w http.ResponseWriter, r *http.Reques
if err := enc.Encode(entry); err != nil { if err := enc.Encode(entry); err != nil {
return return
} }
w.(http.Flusher).Flush() if len(ch) == 0 {
case <-keepAliveTicker.C: w.(http.Flusher).Flush()
if err := enc.Encode(&madmin.LogInfo{}); err != nil { }
return case <-keepAliveTicker.C:
if len(ch) == 0 {
if err := enc.Encode(&madmin.LogInfo{}); err != nil {
return
}
w.(http.Flusher).Flush()
} }
w.(http.Flusher).Flush()
case <-r.Context().Done(): case <-r.Context().Done():
return return
} }

View File

@ -67,6 +67,7 @@ const (
ObjectReplicationAll ObjectReplicationAll
ObjectRestorePostAll ObjectRestorePostAll
ObjectTransitionAll ObjectTransitionAll
Everything
) )
// The number of single names should not exceed 64. // The number of single names should not exceed 64.
@ -112,6 +113,12 @@ func (name Name) Expand() []Name {
ObjectTransitionFailed, ObjectTransitionFailed,
ObjectTransitionComplete, ObjectTransitionComplete,
} }
case Everything:
res := make([]Name, objectSingleTypesEnd-1)
for i := range res {
res[i] = Name(i + 1)
}
return res
default: default:
return []Name{name} return []Name{name}
} }

View File

@ -24,28 +24,28 @@ import (
) )
// Sub - subscriber entity. // Sub - subscriber entity.
type Sub struct { type Sub[T Maskable] struct {
ch chan Maskable ch chan T
types Mask types Mask
filter func(entry Maskable) bool filter func(entry T) bool
} }
// PubSub holds publishers and subscribers // PubSub holds publishers and subscribers
type PubSub struct { type PubSub[T Maskable, M Maskable] struct {
// atomics, keep at top: // atomics, keep at top:
types uint64 types uint64
numSubscribers int32 numSubscribers int32
maxSubscribers int32 maxSubscribers int32
// not atomics: // not atomics:
subs []*Sub subs []*Sub[T]
sync.RWMutex sync.RWMutex
} }
// Publish message to the subscribers. // Publish message to the subscribers.
// Note that publish is always nob-blocking send so that we don't block on slow receivers. // Note that publish is always nob-blocking send so that we don't block on slow receivers.
// Hence receivers should use buffered channel so as not to miss the published events. // Hence receivers should use buffered channel so as not to miss the published events.
func (ps *PubSub) Publish(item Maskable) { func (ps *PubSub[T, M]) Publish(item T) {
ps.RLock() ps.RLock()
defer ps.RUnlock() defer ps.RUnlock()
for _, sub := range ps.subs { for _, sub := range ps.subs {
@ -59,7 +59,7 @@ func (ps *PubSub) Publish(item Maskable) {
} }
// Subscribe - Adds a subscriber to pubsub system // Subscribe - Adds a subscriber to pubsub system
func (ps *PubSub) Subscribe(mask Mask, subCh chan Maskable, doneCh <-chan struct{}, filter func(entry Maskable) bool) error { func (ps *PubSub[T, M]) Subscribe(mask M, subCh chan T, doneCh <-chan struct{}, filter func(entry T) bool) error {
totalSubs := atomic.AddInt32(&ps.numSubscribers, 1) totalSubs := atomic.AddInt32(&ps.numSubscribers, 1)
if ps.maxSubscribers > 0 && totalSubs > ps.maxSubscribers { if ps.maxSubscribers > 0 && totalSubs > ps.maxSubscribers {
atomic.AddInt32(&ps.numSubscribers, -1) atomic.AddInt32(&ps.numSubscribers, -1)
@ -68,12 +68,12 @@ func (ps *PubSub) Subscribe(mask Mask, subCh chan Maskable, doneCh <-chan struct
ps.Lock() ps.Lock()
defer ps.Unlock() defer ps.Unlock()
sub := &Sub{ch: subCh, types: mask, filter: filter} sub := &Sub[T]{ch: subCh, types: Mask(mask.Mask()), filter: filter}
ps.subs = append(ps.subs, sub) ps.subs = append(ps.subs, sub)
// We hold a lock, so we are safe to update // We hold a lock, so we are safe to update
combined := Mask(atomic.LoadUint64(&ps.types)) combined := Mask(atomic.LoadUint64(&ps.types))
combined.Merge(mask) combined.Merge(Mask(mask.Mask()))
atomic.StoreUint64(&ps.types, uint64(combined)) atomic.StoreUint64(&ps.types, uint64(combined))
go func() { go func() {
@ -97,21 +97,23 @@ func (ps *PubSub) Subscribe(mask Mask, subCh chan Maskable, doneCh <-chan struct
} }
// NumSubscribers returns the number of current subscribers, // NumSubscribers returns the number of current subscribers,
// If t is non-nil, the type is checked against the active subscribed types, // The mask is checked against the active subscribed types,
// and 0 will be returned if nobody is subscribed for the type, // and 0 will be returned if nobody is subscribed for the type(s).
// otherwise the *total* number of subscribers is returned. func (ps *PubSub[T, M]) NumSubscribers(mask M) int32 {
func (ps *PubSub) NumSubscribers(m Maskable) int32 { types := Mask(atomic.LoadUint64(&ps.types))
if m != nil { if !types.Overlaps(Mask(mask.Mask())) {
types := Mask(atomic.LoadUint64(&ps.types)) return 0
if !types.Overlaps(Mask(m.Mask())) {
return 0
}
} }
return atomic.LoadInt32(&ps.numSubscribers) return atomic.LoadInt32(&ps.numSubscribers)
} }
// Subscribers returns the number of current subscribers for all types.
func (ps *PubSub[T, M]) Subscribers() int32 {
return atomic.LoadInt32(&ps.numSubscribers)
}
// New inits a PubSub system with a limit of maximum // New inits a PubSub system with a limit of maximum
// subscribers unless zero is specified // subscribers unless zero is specified
func New(maxSubscribers int32) *PubSub { func New[T Maskable, M Maskable](maxSubscribers int32) *PubSub[T, M] {
return &PubSub{maxSubscribers: maxSubscribers} return &PubSub[T, M]{maxSubscribers: maxSubscribers}
} }

View File

@ -24,7 +24,7 @@ import (
) )
func TestSubscribe(t *testing.T) { func TestSubscribe(t *testing.T) {
ps := New(2) ps := New[Maskable, Mask](2)
ch1 := make(chan Maskable, 1) ch1 := make(chan Maskable, 1)
ch2 := make(chan Maskable, 1) ch2 := make(chan Maskable, 1)
doneCh := make(chan struct{}) doneCh := make(chan struct{})
@ -38,21 +38,21 @@ func TestSubscribe(t *testing.T) {
ps.Lock() ps.Lock()
defer ps.Unlock() defer ps.Unlock()
if len(ps.subs) != 2 || ps.NumSubscribers(nil) != 2 { if len(ps.subs) != 2 || ps.NumSubscribers(MaskAll) != 2 || ps.Subscribers() != 2 {
t.Fatalf("expected 2 subscribers") t.Fatalf("expected 2 subscribers")
} }
} }
func TestNumSubscribersMask(t *testing.T) { func TestNumSubscribersMask(t *testing.T) {
ps := New(2) ps := New[Maskable, Mask](2)
ch1 := make(chan Maskable, 1) ch1 := make(chan Maskable, 1)
ch2 := make(chan Maskable, 1) ch2 := make(chan Maskable, 1)
doneCh := make(chan struct{}) doneCh := make(chan struct{})
defer close(doneCh) defer close(doneCh)
if err := ps.Subscribe(1, ch1, doneCh, nil); err != nil { if err := ps.Subscribe(Mask(1), ch1, doneCh, nil); err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if err := ps.Subscribe(2, ch2, doneCh, nil); err != nil { if err := ps.Subscribe(Mask(2), ch2, doneCh, nil); err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
ps.Lock() ps.Lock()
@ -70,7 +70,7 @@ func TestNumSubscribersMask(t *testing.T) {
if want, got := int32(2), ps.NumSubscribers(Mask(1|2)); got != want { if want, got := int32(2), ps.NumSubscribers(Mask(1|2)); got != want {
t.Fatalf("want %d subscribers, got %d", want, got) t.Fatalf("want %d subscribers, got %d", want, got)
} }
if want, got := int32(2), ps.NumSubscribers(nil); got != want { if want, got := int32(2), ps.NumSubscribers(MaskAll); got != want {
t.Fatalf("want %d subscribers, got %d", want, got) t.Fatalf("want %d subscribers, got %d", want, got)
} }
if want, got := int32(0), ps.NumSubscribers(Mask(4)); got != want { if want, got := int32(0), ps.NumSubscribers(Mask(4)); got != want {
@ -79,7 +79,7 @@ func TestNumSubscribersMask(t *testing.T) {
} }
func TestSubscribeExceedingLimit(t *testing.T) { func TestSubscribeExceedingLimit(t *testing.T) {
ps := New(2) ps := New[Maskable, Maskable](2)
ch1 := make(chan Maskable, 1) ch1 := make(chan Maskable, 1)
ch2 := make(chan Maskable, 1) ch2 := make(chan Maskable, 1)
ch3 := make(chan Maskable, 1) ch3 := make(chan Maskable, 1)
@ -97,7 +97,7 @@ func TestSubscribeExceedingLimit(t *testing.T) {
} }
func TestUnsubscribe(t *testing.T) { func TestUnsubscribe(t *testing.T) {
ps := New(2) ps := New[Maskable, Maskable](2)
ch1 := make(chan Maskable, 1) ch1 := make(chan Maskable, 1)
ch2 := make(chan Maskable, 1) ch2 := make(chan Maskable, 1)
doneCh1 := make(chan struct{}) doneCh1 := make(chan struct{})
@ -127,7 +127,7 @@ func (m maskString) Mask() uint64 {
} }
func TestPubSub(t *testing.T) { func TestPubSub(t *testing.T) {
ps := New(1) ps := New[Maskable, Maskable](1)
ch1 := make(chan Maskable, 1) ch1 := make(chan Maskable, 1)
doneCh1 := make(chan struct{}) doneCh1 := make(chan struct{})
defer close(doneCh1) defer close(doneCh1)
@ -143,7 +143,7 @@ func TestPubSub(t *testing.T) {
} }
func TestMultiPubSub(t *testing.T) { func TestMultiPubSub(t *testing.T) {
ps := New(2) ps := New[Maskable, Maskable](2)
ch1 := make(chan Maskable, 1) ch1 := make(chan Maskable, 1)
ch2 := make(chan Maskable, 1) ch2 := make(chan Maskable, 1)
doneCh := make(chan struct{}) doneCh := make(chan struct{})
@ -165,7 +165,7 @@ func TestMultiPubSub(t *testing.T) {
} }
func TestMultiPubSubMask(t *testing.T) { func TestMultiPubSubMask(t *testing.T) {
ps := New(3) ps := New[Maskable, Maskable](3)
ch1 := make(chan Maskable, 1) ch1 := make(chan Maskable, 1)
ch2 := make(chan Maskable, 1) ch2 := make(chan Maskable, 1)
ch3 := make(chan Maskable, 1) ch3 := make(chan Maskable, 1)