Codebase list golang-github-bsm-redeo / HEAD pubsub.go
HEAD

Tree @HEAD (Download .tar.gz)

pubsub.go @HEAD · raw · history · blame

package redeo

import (
	"sync"
	"sync/atomic"

	"github.com/bsm/redeo/resp"
)

// PubSubBroker keeps a registry of named pub/sub channels and exposes
// handlers that emulate redis' native pub/sub functionality.
type PubSubBroker struct {
	// mu guards channels.
	mu       sync.RWMutex
	channels map[string]*pubSubChannel
}

// NewPubSubBroker creates a broker with an empty, ready-to-use channel
// registry.
func NewPubSubBroker() *PubSubBroker {
	broker := new(PubSubBroker)
	broker.channels = make(map[string]*pubSubChannel)
	return broker
}

// Subscribe returns a handler that subscribes the calling client's
// response writer to the channel named by the command's single argument.
// A command with any other argument count is answered with the error
// built by WrongNumberOfArgs.
func (b *PubSubBroker) Subscribe() Handler {
	handle := func(w resp.ResponseWriter, c *resp.Command) {
		if c.ArgN() == 1 {
			b.subscribe(c.Arg(0).String(), w)
			return
		}
		w.AppendError(WrongNumberOfArgs(c.Name))
	}
	return HandlerFunc(handle)
}

// Publish returns a handler that expects exactly two arguments: the
// channel name followed by the message. It publishes the message through
// PublishMessage and replies with the integer that call returns. A
// command with any other argument count is answered with the error built
// by WrongNumberOfArgs.
func (b *PubSubBroker) Publish() Handler {
	handle := func(w resp.ResponseWriter, c *resp.Command) {
		if c.ArgN() != 2 {
			w.AppendError(WrongNumberOfArgs(c.Name))
			return
		}

		channel, payload := c.Arg(0).String(), c.Arg(1).String()
		w.AppendInt(b.PublishMessage(channel, payload))
	}
	return HandlerFunc(handle)
}

// PublishMessage publishes msg on the channel called name without going
// through the command cycle. It returns the value reported by the
// channel's Publish, or 0 when no channel with that name is registered.
func (b *PubSubBroker) PublishMessage(name, msg string) int64 {
	// Hold the read lock only for the lookup, not for the delivery.
	b.mu.RLock()
	target, found := b.channels[name]
	b.mu.RUnlock()

	if !found {
		return 0
	}
	return target.Publish(name, msg)
}

// subscribe registers w on the channel called name, creating the channel
// on first use, and then appends a three-element reply to w: the bulk
// strings "subscribe" and name followed by the integer 1.
func (b *PubSubBroker) subscribe(name string, w resp.ResponseWriter) {
	// Fast path: look the channel up under the read lock.
	b.mu.RLock()
	target, found := b.channels[name]
	b.mu.RUnlock()

	if !found {
		b.mu.Lock()
		// Look again under the write lock: the channel may have been
		// created between releasing the read lock and acquiring this one.
		target, found = b.channels[name]
		if !found {
			target = &pubSubChannel{
				subscribers: make(map[int64]resp.ResponseWriter),
			}
			b.channels[name] = target
		}
		b.mu.Unlock()
	}

	target.Subscribe(w)

	w.AppendArrayLen(3)
	w.AppendBulkString("subscribe")
	w.AppendBulkString(name)
	w.AppendInt(1)
}

// --------------------------------------------------------------------

// pubSubChannel holds the subscribers of a single channel, keyed by a
// per-channel subscriber ID.
type pubSubChannel struct {
	// nextID is incremented with sync/atomic to hand out subscriber IDs.
	// It must remain the first field: on 32-bit platforms (386, ARM,
	// 32-bit MIPS) only the first word of an allocated struct is
	// guaranteed to be 64-bit aligned, and 64-bit atomic operations on a
	// misaligned word panic.
	nextID int64

	// mu guards subscribers.
	mu          sync.RWMutex
	subscribers map[int64]resp.ResponseWriter
}

// Subscribe stores w in the subscriber map under a newly allocated ID.
func (c *pubSubChannel) Subscribe(w resp.ResponseWriter) {
	// The ID is taken atomically before the lock is acquired.
	id := atomic.AddInt64(&c.nextID, 1)

	c.mu.Lock()
	defer c.mu.Unlock()
	c.subscribers[id] = w
}

// Publish writes msg to every subscriber of the channel as a
// three-element array made of the bulk strings "message", name and msg,
// flushing each writer after its frame. It returns the number of
// subscribers whose flush succeeded; subscribers whose flush failed are
// evicted once the delivery loop has finished.
//
// The exclusive lock is held for the whole delivery loop. Each frame is
// assembled from several separate Append calls followed by a Flush, so
// publishers running concurrently under a shared read lock could
// interleave their frames on the same writer.
//
// NOTE(review): this only serialises publishers of this one channel. A
// writer subscribed to several channels can presumably still be written
// to by their publishers at the same time; confirm whether a per-writer
// lock is needed.
func (c *pubSubChannel) Publish(name, msg string) int64 {
	var (
		delivered int64
		failed    []int64
	)

	c.mu.Lock()
	for sid, w := range c.subscribers {
		w.AppendArrayLen(3)
		w.AppendBulkString("message")
		w.AppendBulkString(name)
		w.AppendBulkString(msg)

		if err := w.Flush(); err != nil {
			failed = append(failed, sid)
			continue
		}
		delivered++
	}
	c.mu.Unlock()

	// evict acquires the lock itself, so it has to run after Unlock.
	if len(failed) != 0 {
		c.evict(failed)
	}
	return delivered
}

// evict removes the subscribers with the given IDs from the channel.
// IDs that are not present in the subscriber map are ignored.
func (c *pubSubChannel) evict(failed []int64) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for i := range failed {
		delete(c.subscribers, failed[i])
	}
}