mirror of
https://github.com/minio/minio.git
synced 2024-12-23 21:55:53 -05:00
mark pubsub type safe via generics (#15961)
This commit is contained in:
parent
6d22e74d11
commit
71954faa3a
@ -50,7 +50,6 @@ import (
|
||||
"github.com/minio/minio/internal/kms"
|
||||
"github.com/minio/minio/internal/logger"
|
||||
"github.com/minio/minio/internal/logger/message/log"
|
||||
"github.com/minio/minio/internal/pubsub"
|
||||
iampolicy "github.com/minio/pkg/iam/policy"
|
||||
xnet "github.com/minio/pkg/net"
|
||||
"github.com/secure-io/sio-go"
|
||||
@ -1504,16 +1503,12 @@ func (a adminAPIHandlers) TraceHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Trace Publisher and peer-trace-client uses nonblocking send and hence does not wait for slow receivers.
|
||||
// Use buffered channel to take care of burst sends or slow w.Write()
|
||||
traceCh := make(chan pubsub.Maskable, 4000)
|
||||
traceCh := make(chan madmin.TraceInfo, 4000)
|
||||
|
||||
peers, _ := newPeerRestClients(globalEndpoints)
|
||||
mask := pubsub.MaskFromMaskable(traceOpts.TraceTypes())
|
||||
|
||||
err = globalTrace.Subscribe(mask, traceCh, ctx.Done(), func(entry pubsub.Maskable) bool {
|
||||
if e, ok := entry.(madmin.TraceInfo); ok {
|
||||
return shouldTrace(e, traceOpts)
|
||||
}
|
||||
return false
|
||||
err = globalTrace.Subscribe(traceOpts.TraceTypes(), traceCh, ctx.Done(), func(entry madmin.TraceInfo) bool {
|
||||
return shouldTrace(entry, traceOpts)
|
||||
})
|
||||
if err != nil {
|
||||
writeErrorResponseJSON(ctx, w, errorCodes.ToAPIErr(ErrSlowDown), r.URL)
|
||||
@ -1585,7 +1580,7 @@ func (a adminAPIHandlers) ConsoleLogHandler(w http.ResponseWriter, r *http.Reque
|
||||
|
||||
setEventStreamHeaders(w)
|
||||
|
||||
logCh := make(chan pubsub.Maskable, 4000)
|
||||
logCh := make(chan log.Info, 4000)
|
||||
|
||||
peers, _ := newPeerRestClients(globalEndpoints)
|
||||
|
||||
@ -1611,9 +1606,8 @@ func (a adminAPIHandlers) ConsoleLogHandler(w http.ResponseWriter, r *http.Reque
|
||||
|
||||
for {
|
||||
select {
|
||||
case entry := <-logCh:
|
||||
log, ok := entry.(log.Info)
|
||||
if ok && log.SendLog(node, logKind) {
|
||||
case log := <-logCh:
|
||||
if log.SendLog(node, logKind) {
|
||||
if err := enc.Encode(log); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -22,7 +22,6 @@ import (
|
||||
"runtime"
|
||||
|
||||
"github.com/minio/madmin-go"
|
||||
"github.com/minio/minio/internal/pubsub"
|
||||
)
|
||||
|
||||
// healTask represents what to heal along with options
|
||||
@ -54,7 +53,7 @@ type healRoutine struct {
|
||||
func activeListeners() int {
|
||||
// Bucket notification and http trace are not costly, it is okay to ignore them
|
||||
// while counting the number of concurrent connections
|
||||
return int(globalHTTPListen.NumSubscribers(pubsub.MaskAll)) + int(globalTrace.NumSubscribers(pubsub.MaskAll))
|
||||
return int(globalHTTPListen.Subscribers()) + int(globalTrace.Subscribers())
|
||||
}
|
||||
|
||||
func waitForLowHTTPReq() {
|
||||
|
@ -37,7 +37,7 @@ const defaultLogBufferCount = 10000
|
||||
// HTTPConsoleLoggerSys holds global console logger state
|
||||
type HTTPConsoleLoggerSys struct {
|
||||
sync.RWMutex
|
||||
pubsub *pubsub.PubSub
|
||||
pubsub *pubsub.PubSub[log.Info, madmin.LogMask]
|
||||
console *console.Target
|
||||
nodeName string
|
||||
logBuf *ring.Ring
|
||||
@ -46,9 +46,8 @@ type HTTPConsoleLoggerSys struct {
|
||||
// NewConsoleLogger - creates new HTTPConsoleLoggerSys with all nodes subscribed to
|
||||
// the console logging pub sub system
|
||||
func NewConsoleLogger(ctx context.Context) *HTTPConsoleLoggerSys {
|
||||
ps := pubsub.New(8)
|
||||
return &HTTPConsoleLoggerSys{
|
||||
pubsub: ps,
|
||||
pubsub: pubsub.New[log.Info, madmin.LogMask](8),
|
||||
console: console.New(),
|
||||
logBuf: ring.New(defaultLogBufferCount),
|
||||
}
|
||||
@ -72,11 +71,11 @@ func (sys *HTTPConsoleLoggerSys) SetNodeName(nodeName string) {
|
||||
// HasLogListeners returns true if console log listeners are registered
|
||||
// for this node or peers
|
||||
func (sys *HTTPConsoleLoggerSys) HasLogListeners() bool {
|
||||
return sys != nil && sys.pubsub.NumSubscribers(madmin.LogMaskAll) > 0
|
||||
return sys != nil && sys.pubsub.Subscribers() > 0
|
||||
}
|
||||
|
||||
// Subscribe starts console logging for this node.
|
||||
func (sys *HTTPConsoleLoggerSys) Subscribe(subCh chan pubsub.Maskable, doneCh <-chan struct{}, node string, last int, logKind madmin.LogMask, filter func(entry pubsub.Maskable) bool) error {
|
||||
func (sys *HTTPConsoleLoggerSys) Subscribe(subCh chan log.Info, doneCh <-chan struct{}, node string, last int, logKind madmin.LogMask, filter func(entry log.Info) bool) error {
|
||||
// Enable console logging for remote client.
|
||||
if !sys.HasLogListeners() {
|
||||
logger.AddSystemTarget(sys)
|
||||
@ -116,7 +115,7 @@ func (sys *HTTPConsoleLoggerSys) Subscribe(subCh chan pubsub.Maskable, doneCh <-
|
||||
}
|
||||
}
|
||||
}
|
||||
return sys.pubsub.Subscribe(pubsub.MaskFromMaskable(madmin.LogMaskAll), subCh, doneCh, filter)
|
||||
return sys.pubsub.Subscribe(madmin.LogMaskAll, subCh, doneCh, filter)
|
||||
}
|
||||
|
||||
// Init if HTTPConsoleLoggerSys is valid, always returns nil right now
|
||||
|
@ -29,6 +29,7 @@ import (
|
||||
"github.com/minio/minio/internal/event"
|
||||
xhttp "github.com/minio/minio/internal/http"
|
||||
"github.com/minio/minio/internal/logger"
|
||||
"github.com/minio/minio/internal/pubsub"
|
||||
"github.com/minio/pkg/bucket/policy"
|
||||
)
|
||||
|
||||
@ -321,7 +322,7 @@ func sendEvent(args eventArgs) {
|
||||
crypto.RemoveSensitiveEntries(args.Object.UserDefined)
|
||||
crypto.RemoveInternalEntries(args.Object.UserDefined)
|
||||
|
||||
if globalHTTPListen.NumSubscribers(args.EventName) > 0 {
|
||||
if globalHTTPListen.NumSubscribers(pubsub.MaskFromMaskable(args.EventName)) > 0 {
|
||||
globalHTTPListen.Publish(args.ToEvent(false))
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/minio/console/restapi"
|
||||
minio "github.com/minio/minio-go/v7"
|
||||
"github.com/minio/madmin-go"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/set"
|
||||
"github.com/minio/minio/internal/bucket/bandwidth"
|
||||
"github.com/minio/minio/internal/config"
|
||||
@ -220,11 +221,10 @@ var (
|
||||
|
||||
// global Trace system to send HTTP request/response
|
||||
// and Storage/OS calls info to registered listeners.
|
||||
globalTrace = pubsub.New(8)
|
||||
globalTrace = pubsub.New[madmin.TraceInfo, madmin.TraceType](8)
|
||||
|
||||
// global Listen system to send S3 API events to registered listeners
|
||||
// Objects are expected to be event.Event
|
||||
globalHTTPListen = pubsub.New(0)
|
||||
globalHTTPListen = pubsub.New[event.Event, pubsub.Mask](0)
|
||||
|
||||
// global console system to send console logs to
|
||||
// registered listeners
|
||||
|
@ -125,15 +125,11 @@ func (api objectAPIHandlers) ListenNotificationHandler(w http.ResponseWriter, r
|
||||
|
||||
// Listen Publisher and peer-listen-client uses nonblocking send and hence does not wait for slow receivers.
|
||||
// Use buffered channel to take care of burst sends or slow w.Write()
|
||||
listenCh := make(chan pubsub.Maskable, 4000)
|
||||
listenCh := make(chan event.Event, 4000)
|
||||
|
||||
peers, _ := newPeerRestClients(globalEndpoints)
|
||||
|
||||
err := globalHTTPListen.Subscribe(mask, listenCh, ctx.Done(), func(evI pubsub.Maskable) bool {
|
||||
ev, ok := evI.(event.Event)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
err := globalHTTPListen.Subscribe(mask, listenCh, ctx.Done(), func(ev event.Event) bool {
|
||||
if ev.S3.Bucket.Name != "" && bucketName != "" {
|
||||
if ev.S3.Bucket.Name != bucketName {
|
||||
return false
|
||||
@ -161,16 +157,9 @@ func (api objectAPIHandlers) ListenNotificationHandler(w http.ResponseWriter, r
|
||||
enc := json.NewEncoder(w)
|
||||
for {
|
||||
select {
|
||||
case evI := <-listenCh:
|
||||
ev, ok := evI.(event.Event)
|
||||
if ok {
|
||||
if err := enc.Encode(struct{ Records []event.Event }{[]event.Event{ev}}); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if _, err := w.Write([]byte(" ")); err != nil {
|
||||
return
|
||||
}
|
||||
case ev := <-listenCh:
|
||||
if err := enc.Encode(struct{ Records []event.Event }{[]event.Event{ev}}); err != nil {
|
||||
return
|
||||
}
|
||||
if len(listenCh) == 0 {
|
||||
// Flush if nothing is queued
|
||||
|
@ -34,7 +34,7 @@ import (
|
||||
"github.com/minio/minio/internal/http"
|
||||
xhttp "github.com/minio/minio/internal/http"
|
||||
"github.com/minio/minio/internal/logger"
|
||||
"github.com/minio/minio/internal/pubsub"
|
||||
"github.com/minio/minio/internal/logger/message/log"
|
||||
"github.com/minio/minio/internal/rest"
|
||||
xnet "github.com/minio/pkg/net"
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
@ -583,7 +583,7 @@ func (client *peerRESTClient) LoadTransitionTierConfig(ctx context.Context) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *peerRESTClient) doTrace(traceCh chan<- pubsub.Maskable, doneCh <-chan struct{}, traceOpts madmin.ServiceTraceOpts) {
|
||||
func (client *peerRESTClient) doTrace(traceCh chan<- madmin.TraceInfo, doneCh <-chan struct{}, traceOpts madmin.ServiceTraceOpts) {
|
||||
values := make(url.Values)
|
||||
traceOpts.AddParams(values)
|
||||
|
||||
@ -624,7 +624,7 @@ func (client *peerRESTClient) doTrace(traceCh chan<- pubsub.Maskable, doneCh <-c
|
||||
}
|
||||
}
|
||||
|
||||
func (client *peerRESTClient) doListen(listenCh chan<- pubsub.Maskable, doneCh <-chan struct{}, v url.Values) {
|
||||
func (client *peerRESTClient) doListen(listenCh chan<- event.Event, doneCh <-chan struct{}, v url.Values) {
|
||||
// To cancel the REST request in case doneCh gets closed.
|
||||
ctx, cancel := context.WithCancel(GlobalContext)
|
||||
|
||||
@ -663,7 +663,7 @@ func (client *peerRESTClient) doListen(listenCh chan<- pubsub.Maskable, doneCh <
|
||||
}
|
||||
|
||||
// Listen - listen on peers.
|
||||
func (client *peerRESTClient) Listen(listenCh chan<- pubsub.Maskable, doneCh <-chan struct{}, v url.Values) {
|
||||
func (client *peerRESTClient) Listen(listenCh chan<- event.Event, doneCh <-chan struct{}, v url.Values) {
|
||||
go func() {
|
||||
for {
|
||||
client.doListen(listenCh, doneCh, v)
|
||||
@ -679,7 +679,7 @@ func (client *peerRESTClient) Listen(listenCh chan<- pubsub.Maskable, doneCh <-c
|
||||
}
|
||||
|
||||
// Trace - send http trace request to peer nodes
|
||||
func (client *peerRESTClient) Trace(traceCh chan<- pubsub.Maskable, doneCh <-chan struct{}, traceOpts madmin.ServiceTraceOpts) {
|
||||
func (client *peerRESTClient) Trace(traceCh chan<- madmin.TraceInfo, doneCh <-chan struct{}, traceOpts madmin.ServiceTraceOpts) {
|
||||
go func() {
|
||||
for {
|
||||
client.doTrace(traceCh, doneCh, traceOpts)
|
||||
@ -694,7 +694,7 @@ func (client *peerRESTClient) Trace(traceCh chan<- pubsub.Maskable, doneCh <-cha
|
||||
}()
|
||||
}
|
||||
|
||||
func (client *peerRESTClient) doConsoleLog(logCh chan pubsub.Maskable, doneCh <-chan struct{}) {
|
||||
func (client *peerRESTClient) doConsoleLog(logCh chan log.Info, doneCh <-chan struct{}) {
|
||||
// To cancel the REST request in case doneCh gets closed.
|
||||
ctx, cancel := context.WithCancel(GlobalContext)
|
||||
|
||||
@ -717,22 +717,20 @@ func (client *peerRESTClient) doConsoleLog(logCh chan pubsub.Maskable, doneCh <-
|
||||
|
||||
dec := gob.NewDecoder(respBody)
|
||||
for {
|
||||
var lg madmin.LogInfo
|
||||
var lg log.Info
|
||||
if err = dec.Decode(&lg); err != nil {
|
||||
break
|
||||
}
|
||||
if lg.DeploymentID != "" {
|
||||
select {
|
||||
case logCh <- lg:
|
||||
default:
|
||||
// Do not block on slow receivers.
|
||||
}
|
||||
select {
|
||||
case logCh <- lg:
|
||||
default:
|
||||
// Do not block on slow receivers.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ConsoleLog - sends request to peer nodes to get console logs
|
||||
func (client *peerRESTClient) ConsoleLog(logCh chan pubsub.Maskable, doneCh <-chan struct{}) {
|
||||
func (client *peerRESTClient) ConsoleLog(logCh chan log.Info, doneCh <-chan struct{}) {
|
||||
go func() {
|
||||
for {
|
||||
client.doConsoleLog(logCh, doneCh)
|
||||
|
@ -35,6 +35,7 @@ import (
|
||||
b "github.com/minio/minio/internal/bucket/bandwidth"
|
||||
"github.com/minio/minio/internal/event"
|
||||
"github.com/minio/minio/internal/logger"
|
||||
"github.com/minio/minio/internal/logger/message/log"
|
||||
"github.com/minio/minio/internal/pubsub"
|
||||
"github.com/tinylib/msgp/msgp"
|
||||
)
|
||||
@ -919,13 +920,9 @@ func (s *peerRESTServer) ListenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Listen Publisher uses nonblocking publish and hence does not wait for slow subscribers.
|
||||
// Use buffered channel to take care of burst sends or slow w.Write()
|
||||
ch := make(chan pubsub.Maskable, 2000)
|
||||
ch := make(chan event.Event, 2000)
|
||||
|
||||
err := globalHTTPListen.Subscribe(mask, ch, doneCh, func(evI pubsub.Maskable) bool {
|
||||
ev, ok := evI.(event.Event)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
err := globalHTTPListen.Subscribe(mask, ch, doneCh, func(ev event.Event) bool {
|
||||
if ev.S3.Bucket.Name != "" && values.Get(peerRESTListenBucket) != "" {
|
||||
if ev.S3.Bucket.Name != values.Get(peerRESTListenBucket) {
|
||||
return false
|
||||
@ -978,13 +975,9 @@ func (s *peerRESTServer) TraceHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Trace Publisher uses nonblocking publish and hence does not wait for slow subscribers.
|
||||
// Use buffered channel to take care of burst sends or slow w.Write()
|
||||
ch := make(chan pubsub.Maskable, 2000)
|
||||
mask := pubsub.MaskFromMaskable(traceOpts.TraceTypes())
|
||||
err = globalTrace.Subscribe(mask, ch, r.Context().Done(), func(entry pubsub.Maskable) bool {
|
||||
if e, ok := entry.(madmin.TraceInfo); ok {
|
||||
return shouldTrace(e, traceOpts)
|
||||
}
|
||||
return false
|
||||
ch := make(chan madmin.TraceInfo, 2000)
|
||||
err = globalTrace.Subscribe(traceOpts.TraceTypes(), ch, r.Context().Done(), func(entry madmin.TraceInfo) bool {
|
||||
return shouldTrace(entry, traceOpts)
|
||||
})
|
||||
if err != nil {
|
||||
s.writeErrorResponse(w, err)
|
||||
@ -1125,7 +1118,7 @@ func (s *peerRESTServer) ConsoleLogHandler(w http.ResponseWriter, r *http.Reques
|
||||
doneCh := make(chan struct{})
|
||||
defer close(doneCh)
|
||||
|
||||
ch := make(chan pubsub.Maskable, 2000)
|
||||
ch := make(chan log.Info, 2000)
|
||||
err := globalConsoleSys.Subscribe(ch, doneCh, "", 0, madmin.LogMaskAll, nil)
|
||||
if err != nil {
|
||||
s.writeErrorResponse(w, err)
|
||||
@ -1141,12 +1134,16 @@ func (s *peerRESTServer) ConsoleLogHandler(w http.ResponseWriter, r *http.Reques
|
||||
if err := enc.Encode(entry); err != nil {
|
||||
return
|
||||
}
|
||||
w.(http.Flusher).Flush()
|
||||
case <-keepAliveTicker.C:
|
||||
if err := enc.Encode(&madmin.LogInfo{}); err != nil {
|
||||
return
|
||||
if len(ch) == 0 {
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
case <-keepAliveTicker.C:
|
||||
if len(ch) == 0 {
|
||||
if err := enc.Encode(&madmin.LogInfo{}); err != nil {
|
||||
return
|
||||
}
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
w.(http.Flusher).Flush()
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
}
|
||||
|
@ -67,6 +67,7 @@ const (
|
||||
ObjectReplicationAll
|
||||
ObjectRestorePostAll
|
||||
ObjectTransitionAll
|
||||
Everything
|
||||
)
|
||||
|
||||
// The number of single names should not exceed 64.
|
||||
@ -112,6 +113,12 @@ func (name Name) Expand() []Name {
|
||||
ObjectTransitionFailed,
|
||||
ObjectTransitionComplete,
|
||||
}
|
||||
case Everything:
|
||||
res := make([]Name, objectSingleTypesEnd-1)
|
||||
for i := range res {
|
||||
res[i] = Name(i + 1)
|
||||
}
|
||||
return res
|
||||
default:
|
||||
return []Name{name}
|
||||
}
|
||||
|
@ -24,28 +24,28 @@ import (
|
||||
)
|
||||
|
||||
// Sub - subscriber entity.
|
||||
type Sub struct {
|
||||
ch chan Maskable
|
||||
type Sub[T Maskable] struct {
|
||||
ch chan T
|
||||
types Mask
|
||||
filter func(entry Maskable) bool
|
||||
filter func(entry T) bool
|
||||
}
|
||||
|
||||
// PubSub holds publishers and subscribers
|
||||
type PubSub struct {
|
||||
type PubSub[T Maskable, M Maskable] struct {
|
||||
// atomics, keep at top:
|
||||
types uint64
|
||||
numSubscribers int32
|
||||
maxSubscribers int32
|
||||
|
||||
// not atomics:
|
||||
subs []*Sub
|
||||
subs []*Sub[T]
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// Publish message to the subscribers.
|
||||
// Note that publish is always nob-blocking send so that we don't block on slow receivers.
|
||||
// Hence receivers should use buffered channel so as not to miss the published events.
|
||||
func (ps *PubSub) Publish(item Maskable) {
|
||||
func (ps *PubSub[T, M]) Publish(item T) {
|
||||
ps.RLock()
|
||||
defer ps.RUnlock()
|
||||
for _, sub := range ps.subs {
|
||||
@ -59,7 +59,7 @@ func (ps *PubSub) Publish(item Maskable) {
|
||||
}
|
||||
|
||||
// Subscribe - Adds a subscriber to pubsub system
|
||||
func (ps *PubSub) Subscribe(mask Mask, subCh chan Maskable, doneCh <-chan struct{}, filter func(entry Maskable) bool) error {
|
||||
func (ps *PubSub[T, M]) Subscribe(mask M, subCh chan T, doneCh <-chan struct{}, filter func(entry T) bool) error {
|
||||
totalSubs := atomic.AddInt32(&ps.numSubscribers, 1)
|
||||
if ps.maxSubscribers > 0 && totalSubs > ps.maxSubscribers {
|
||||
atomic.AddInt32(&ps.numSubscribers, -1)
|
||||
@ -68,12 +68,12 @@ func (ps *PubSub) Subscribe(mask Mask, subCh chan Maskable, doneCh <-chan struct
|
||||
ps.Lock()
|
||||
defer ps.Unlock()
|
||||
|
||||
sub := &Sub{ch: subCh, types: mask, filter: filter}
|
||||
sub := &Sub[T]{ch: subCh, types: Mask(mask.Mask()), filter: filter}
|
||||
ps.subs = append(ps.subs, sub)
|
||||
|
||||
// We hold a lock, so we are safe to update
|
||||
combined := Mask(atomic.LoadUint64(&ps.types))
|
||||
combined.Merge(mask)
|
||||
combined.Merge(Mask(mask.Mask()))
|
||||
atomic.StoreUint64(&ps.types, uint64(combined))
|
||||
|
||||
go func() {
|
||||
@ -97,21 +97,23 @@ func (ps *PubSub) Subscribe(mask Mask, subCh chan Maskable, doneCh <-chan struct
|
||||
}
|
||||
|
||||
// NumSubscribers returns the number of current subscribers,
|
||||
// If t is non-nil, the type is checked against the active subscribed types,
|
||||
// and 0 will be returned if nobody is subscribed for the type,
|
||||
// otherwise the *total* number of subscribers is returned.
|
||||
func (ps *PubSub) NumSubscribers(m Maskable) int32 {
|
||||
if m != nil {
|
||||
types := Mask(atomic.LoadUint64(&ps.types))
|
||||
if !types.Overlaps(Mask(m.Mask())) {
|
||||
return 0
|
||||
}
|
||||
// The mask is checked against the active subscribed types,
|
||||
// and 0 will be returned if nobody is subscribed for the type(s).
|
||||
func (ps *PubSub[T, M]) NumSubscribers(mask M) int32 {
|
||||
types := Mask(atomic.LoadUint64(&ps.types))
|
||||
if !types.Overlaps(Mask(mask.Mask())) {
|
||||
return 0
|
||||
}
|
||||
return atomic.LoadInt32(&ps.numSubscribers)
|
||||
}
|
||||
|
||||
// Subscribers returns the number of current subscribers for all types.
|
||||
func (ps *PubSub[T, M]) Subscribers() int32 {
|
||||
return atomic.LoadInt32(&ps.numSubscribers)
|
||||
}
|
||||
|
||||
// New inits a PubSub system with a limit of maximum
|
||||
// subscribers unless zero is specified
|
||||
func New(maxSubscribers int32) *PubSub {
|
||||
return &PubSub{maxSubscribers: maxSubscribers}
|
||||
func New[T Maskable, M Maskable](maxSubscribers int32) *PubSub[T, M] {
|
||||
return &PubSub[T, M]{maxSubscribers: maxSubscribers}
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ import (
|
||||
)
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
ps := New(2)
|
||||
ps := New[Maskable, Mask](2)
|
||||
ch1 := make(chan Maskable, 1)
|
||||
ch2 := make(chan Maskable, 1)
|
||||
doneCh := make(chan struct{})
|
||||
@ -38,21 +38,21 @@ func TestSubscribe(t *testing.T) {
|
||||
ps.Lock()
|
||||
defer ps.Unlock()
|
||||
|
||||
if len(ps.subs) != 2 || ps.NumSubscribers(nil) != 2 {
|
||||
if len(ps.subs) != 2 || ps.NumSubscribers(MaskAll) != 2 || ps.Subscribers() != 2 {
|
||||
t.Fatalf("expected 2 subscribers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumSubscribersMask(t *testing.T) {
|
||||
ps := New(2)
|
||||
ps := New[Maskable, Mask](2)
|
||||
ch1 := make(chan Maskable, 1)
|
||||
ch2 := make(chan Maskable, 1)
|
||||
doneCh := make(chan struct{})
|
||||
defer close(doneCh)
|
||||
if err := ps.Subscribe(1, ch1, doneCh, nil); err != nil {
|
||||
if err := ps.Subscribe(Mask(1), ch1, doneCh, nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if err := ps.Subscribe(2, ch2, doneCh, nil); err != nil {
|
||||
if err := ps.Subscribe(Mask(2), ch2, doneCh, nil); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
ps.Lock()
|
||||
@ -70,7 +70,7 @@ func TestNumSubscribersMask(t *testing.T) {
|
||||
if want, got := int32(2), ps.NumSubscribers(Mask(1|2)); got != want {
|
||||
t.Fatalf("want %d subscribers, got %d", want, got)
|
||||
}
|
||||
if want, got := int32(2), ps.NumSubscribers(nil); got != want {
|
||||
if want, got := int32(2), ps.NumSubscribers(MaskAll); got != want {
|
||||
t.Fatalf("want %d subscribers, got %d", want, got)
|
||||
}
|
||||
if want, got := int32(0), ps.NumSubscribers(Mask(4)); got != want {
|
||||
@ -79,7 +79,7 @@ func TestNumSubscribersMask(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSubscribeExceedingLimit(t *testing.T) {
|
||||
ps := New(2)
|
||||
ps := New[Maskable, Maskable](2)
|
||||
ch1 := make(chan Maskable, 1)
|
||||
ch2 := make(chan Maskable, 1)
|
||||
ch3 := make(chan Maskable, 1)
|
||||
@ -97,7 +97,7 @@ func TestSubscribeExceedingLimit(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
ps := New(2)
|
||||
ps := New[Maskable, Maskable](2)
|
||||
ch1 := make(chan Maskable, 1)
|
||||
ch2 := make(chan Maskable, 1)
|
||||
doneCh1 := make(chan struct{})
|
||||
@ -127,7 +127,7 @@ func (m maskString) Mask() uint64 {
|
||||
}
|
||||
|
||||
func TestPubSub(t *testing.T) {
|
||||
ps := New(1)
|
||||
ps := New[Maskable, Maskable](1)
|
||||
ch1 := make(chan Maskable, 1)
|
||||
doneCh1 := make(chan struct{})
|
||||
defer close(doneCh1)
|
||||
@ -143,7 +143,7 @@ func TestPubSub(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMultiPubSub(t *testing.T) {
|
||||
ps := New(2)
|
||||
ps := New[Maskable, Maskable](2)
|
||||
ch1 := make(chan Maskable, 1)
|
||||
ch2 := make(chan Maskable, 1)
|
||||
doneCh := make(chan struct{})
|
||||
@ -165,7 +165,7 @@ func TestMultiPubSub(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMultiPubSubMask(t *testing.T) {
|
||||
ps := New(3)
|
||||
ps := New[Maskable, Maskable](3)
|
||||
ch1 := make(chan Maskable, 1)
|
||||
ch2 := make(chan Maskable, 1)
|
||||
ch3 := make(chan Maskable, 1)
|
||||
|
Loading…
Reference in New Issue
Block a user