From 95895044a28bbfa54a46a5a2373cfdd28601ecba Mon Sep 17 00:00:00 2001 From: tommy <3405129587@qq.com> Date: Fri, 6 Mar 2020 19:40:40 +0800 Subject: [PATCH] add sub pub --- .gitignore | 4 +++ channel.go | 5 +++ envelope.go | 2 ++ examples/channel/main.go | 36 ++++++++++++++++++++++ hub.go | 51 +++++++++++++++++++++++++++++++ melody.go | 57 ++++++++++++++++++++++++++++++++++ session.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 234 insertions(+) create mode 100644 channel.go create mode 100644 examples/channel/main.go diff --git a/.gitignore b/.gitignore index 31ea56b..492e59d 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,7 @@ benchmark *.swp coverage.out Makefile +go.sum +go.mod +todo.md +.idea \ No newline at end of file diff --git a/channel.go b/channel.go new file mode 100644 index 0000000..d9e4a3e --- /dev/null +++ b/channel.go @@ -0,0 +1,5 @@ +package melody + +type Channel struct { + +} \ No newline at end of file diff --git a/envelope.go b/envelope.go index baa55a3..2cd2aea 100644 --- a/envelope.go +++ b/envelope.go @@ -4,4 +4,6 @@ type envelope struct { t int msg []byte filter filterFunc + // extends + c string } diff --git a/examples/channel/main.go b/examples/channel/main.go new file mode 100644 index 0000000..9de8265 --- /dev/null +++ b/examples/channel/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "log" + "melody" + + "github.com/gin-gonic/gin" +) + +func main() { + m := melody.New() + router := gin.Default() + router.GET("/chat", func(context *gin.Context) { + if err := m.HandleRequest(context.Writer, context.Request); err != nil { + log.Println(err) + } + }) + + m.HandleConnect(func(session *melody.Session) { + ch := session.Request.URL.Query().Get("channel") + if err := session.Subscribe(ch); err != nil { + log.Println(err) + } + }) + + m.HandleMessage(func(session *melody.Session, msg []byte) { + if err := session.Publish(msg); err != nil { + log.Println(err) + } + }) + + m.HandleSentMessage(func(session *melody.Session, bytes []byte) { + log.Printf("%+v", string(bytes)) + }) + router.Run(":8080") +} diff --git a/hub.go b/hub.go index edc6337..ea2d194 100644 --- a/hub.go +++ b/hub.go @@ -12,6 +12,11 @@ type hub struct { exit chan *envelope open bool rwmutex *sync.RWMutex + // extends + channels map[string]*Session + subscribe chan *Session + publish chan *envelope + unsubscribe chan *Session } func newHub() *hub { @@ -23,6 +28,11 @@ func newHub() *hub { exit: make(chan *envelope), open: true, rwmutex: &sync.RWMutex{}, + // extends + channels: make(map[string]*Session), + subscribe: make(chan *Session), + publish: make(chan *envelope), + unsubscribe: make(chan *Session), } } @@ -62,6 +72,47 @@ loop: h.open = false h.rwmutex.Unlock() break loop + + case s := <-h.subscribe: + // extends + h.rwmutex.Lock() + if _, ok := h.channels[s.channel]; !ok { + h.channels[s.channel] = s + } else { + h.channels[s.channel].prev = s // 原来的上一个指向现在的 + s.next = h.channels[s.channel] // 现在的下一个是原来的 + s.prev = nil // 成为队头 + h.channels[s.channel] = s // 现在的替代原来的位置 + } + h.rwmutex.Unlock() + case s := <-h.unsubscribe: + if _, ok := h.channels[s.channel]; ok { + h.rwmutex.Lock() + if s.next != nil { + s.next.prev = s.prev + } + if s.prev != nil { + s.prev.next = s.next + } else { + h.channels[s.channel] = s.next + } + s.channel = "" // 置空 + h.rwmutex.Unlock() + } + case m := <-h.publish: + h.rwmutex.RLock() + if _, ok := h.channels[m.c]; ok { + for s := h.channels[m.c]; s != nil; s = s.next { + if m.filter != nil { + if m.filter(s) { + s.writeMessage(m) + } + } else { + s.writeMessage(m) + } + } + } + h.rwmutex.RUnlock() } } } diff --git a/melody.go b/melody.go index d845876..50fcc40 100644 --- a/melody.go +++ b/melody.go @@ -181,6 +181,10 @@ func (m *Melody) HandleRequestWithKeys(w http.ResponseWriter, r *http.Request, k melody: m, open: true, rwmutex: &sync.RWMutex{}, + // extends + next: nil, + prev: nil, + channel: "", } m.hub.register <- session @@ -311,3 +315,56 @@ func (m *Melody) IsClosed() bool { func FormatCloseMessage(closeCode int, text string) []byte { return websocket.FormatCloseMessage(closeCode, text) } + +// extends +func (m *Melody) Subscribe(s *Session, c string) error { + if m.hub.closed() { + return errors.New("melody instance is already closed") + } + return s.Subscribe(c) +} + +func (m *Melody) Unsubscribe(s *Session, c string) error { + if m.hub.closed() { + return errors.New("melody instance is already closed") + } + return s.Subscribe(c) +} + +func (m *Melody) Publish(msg []byte, c string) error { + return m.PublishFilter(msg, c, nil) +} + +func (m *Melody) PublishOthers(msg []byte, c string, s *Session) error { + return m.PublishFilter(msg, c, func(session *Session) bool { + return s != session + }) +} + +func (m *Melody) PublishFilter(msg []byte, c string, fn func(*Session) bool) error { + if m.hub.closed() { + return errors.New("melody instance is already closed") + } + message := &envelope{t: websocket.TextMessage, msg: msg, c: c, filter: fn} + m.hub.publish <- message + return nil +} + +func (m *Melody) PublishBinary(msg []byte, c string) error { + return m.PublishBinaryFilter(msg, c, nil) +} + +func (m *Melody) PublishBinaryOthers(msg []byte, c string, s *Session) error { + return m.PublishFilter(msg, c, func(session *Session) bool { + return s != session + }) +} + +func (m *Melody) PublishBinaryFilter(msg []byte, c string, fn func(*Session) bool) error { + if m.hub.closed() { + return errors.New("melody instance is already closed") + } + message := &envelope{t: websocket.BinaryMessage, msg: msg, c: c, filter: fn} + m.hub.publish <- message + return nil +} diff --git a/session.go b/session.go index 3997cef..7ed6a1c 100644 --- a/session.go +++ b/session.go @@ -18,6 +18,10 @@ type Session struct { melody *Melody open bool rwmutex *sync.RWMutex + // extends + channel string + next *Session + prev *Session } func (s *Session) writeMessage(message *envelope) { @@ -217,3 +221,78 @@ func (s *Session) MustGet(key string) interface{} { func (s *Session) IsClosed() bool { return s.closed() } + +// extends +func (s *Session) Subscribe(c string) error { + // + if s.closed() { + return errors.New("tried to write to closed a session") + } + + if s.channel != "" && s.channel != c { // 存在订阅 并与接下来的channel无法共存 + return errors.New("session already subscribe channel") + } else { + s.channel = c + s.melody.hub.subscribe <- s + } + return nil +} + +func (s *Session) Unsubscribe() error { + if s.closed() { + return errors.New("tried to write to closed a session") + } + if s.channel == "" { + return errors.New("session not yet subscribe channel") + } else { + s.melody.hub.unsubscribe <- s + } + return nil +} + +func (s *Session) Publish(msg []byte) error { + return s.PublishFilter(msg, nil) +} + +func (s *Session) PublishFilter(msg []byte, fn func(*Session) bool) error { + if s.melody.hub.closed() { + return errors.New("melody instance is closed") + } + message := &envelope{t: websocket.TextMessage, msg: msg, c: s.channel, filter: fn} + s.melody.hub.publish <- message + return nil +} + +func (s *Session) PublishOthers(msg []byte) error { + return s.PublishFilter(msg, func(session *Session) bool { + return s != session + }) +} + +func (s *Session) PublishBinary(msg []byte) error { + return s.PublishBinaryFilter(msg, nil) +} + +func (s *Session) PublishBinaryOthers(msg []byte) error { + return s.PublishBinaryFilter(msg, func(session *Session) bool { + return s != session + }) +} + +func (s *Session) PublishBinaryFilter(msg []byte, fn func(*Session) bool) error { + if s.melody.hub.closed() { + return errors.New("melody instance is closed") + } + + message := &envelope{t: websocket.BinaryMessage, msg: msg, c: s.channel, filter: fn} + s.melody.hub.publish <- message + return nil +} + +func (s *Session) IsSubscribed() bool { + return s.channel != "" +} + +func (s *Session) SubscribeName() string { + return s.channel +}