diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b18272..e31f26d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 2017-01-20 + +* Add `Len()` to fetch number of connected sessions. + ## 2016-12-09 * Add metadata management for sessions. diff --git a/hub.go b/hub.go index 097dbaa..85c51a9 100644 --- a/hub.go +++ b/hub.go @@ -1,5 +1,9 @@ package melody +import ( + "sync" +) + type hub struct { sessions map[*Session]bool broadcast chan *envelope @@ -7,6 +11,7 @@ type hub struct { unregister chan *Session exit chan bool open bool + rwmutex *sync.RWMutex } func newHub() *hub { @@ -17,6 +22,7 @@ func newHub() *hub { unregister: make(chan *Session), exit: make(chan bool), open: true, + rwmutex: &sync.RWMutex{}, } } @@ -25,14 +31,19 @@ loop: for { select { case s := <-h.register: + h.rwmutex.Lock() h.sessions[s] = true + h.rwmutex.Unlock() case s := <-h.unregister: if _, ok := h.sessions[s]; ok { + h.rwmutex.Lock() delete(h.sessions, s) s.conn.Close() close(s.output) + h.rwmutex.Unlock() } case m := <-h.broadcast: + h.rwmutex.RLock() for s := range h.sessions { if m.filter != nil { if m.filter(s) { @@ -42,14 +53,24 @@ loop: s.writeMessage(m) } } + h.rwmutex.RUnlock() case <-h.exit: + h.rwmutex.Lock() for s := range h.sessions { delete(h.sessions, s) s.conn.Close() close(s.output) } h.open = false + h.rwmutex.Unlock() break loop } } } + +func (h *hub) len() int { + h.rwmutex.RLock() + defer h.rwmutex.RUnlock() + + return len(h.sessions) +} diff --git a/melody.go b/melody.go index 26002f0..0ec990a 100644 --- a/melody.go +++ b/melody.go @@ -153,3 +153,8 @@ func (m *Melody) BroadcastBinaryOthers(msg []byte, s *Session) { func (m *Melody) Close() { m.hub.exit <- true } + +// Len return the number of connected sessions. +func (m *Melody) Len() int { + return m.hub.len() +} diff --git a/melody_test.go b/melody_test.go index fa3ba71..7a68527 100644 --- a/melody_test.go +++ b/melody_test.go @@ -3,6 +3,7 @@ package melody import ( "bytes" "github.com/gorilla/websocket" + "math/rand" "net/http" "net/http/httptest" "strconv" @@ -79,6 +80,51 @@ func TestEcho(t *testing.T) { } } +func TestLen(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + + connect := int(rand.Int31n(1000)) + disconnect := rand.Float32() + conns := make([]*websocket.Conn, connect) + defer func() { + for _, conn := range conns { + if conn != nil { + conn.Close() + } + } + }() + + echo := NewTestServerHandler(func(session *Session, msg []byte) {}) + server := httptest.NewServer(echo) + defer server.Close() + + disconnected := 0 + for i := 0; i < connect; i++ { + conn, err := NewDialer(server.URL) + + if err != nil { + t.Error(err) + } + + if rand.Float32() < disconnect { + conns[i] = nil + disconnected += 1 + conn.Close() + continue + } + + conns[i] = conn + } + + time.Sleep(time.Millisecond) + + connected := connect - disconnected + + if echo.m.Len() != connected { + t.Errorf("melody len %d should equal %d", echo.m.Len(), connected) + } +} + func TestEchoBinary(t *testing.T) { echo := NewTestServer() echo.m.HandleMessageBinary(func(session *Session, msg []byte) {