From 3321ec3da72152d5003846b77322e2db0ef562ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ola=20Holmstr=C3=B6m?= Date: Fri, 10 Feb 2017 23:12:27 +0100 Subject: [PATCH] Return error messages for some exposed methods, fix panic when connection is closed. Close #21. --- hub.go | 11 +++++--- melody.go | 69 ++++++++++++++++++++++++++++++++++++----------- melody_test.go | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ session.go | 84 +++++++++++++++++++++++++++++++++++++++------------------- 4 files changed, 199 insertions(+), 46 deletions(-) diff --git a/hub.go b/hub.go index 85c51a9..fba1653 100644 --- a/hub.go +++ b/hub.go @@ -38,8 +38,6 @@ loop: if _, ok := h.sessions[s]; ok { h.rwmutex.Lock() delete(h.sessions, s) - s.conn.Close() - close(s.output) h.rwmutex.Unlock() } case m := <-h.broadcast: @@ -58,8 +56,7 @@ loop: h.rwmutex.Lock() for s := range h.sessions { delete(h.sessions, s) - s.conn.Close() - close(s.output) + s.Close() } h.open = false h.rwmutex.Unlock() @@ -68,6 +65,12 @@ loop: } } +func (h *hub) closed() bool { + h.rwmutex.RLock() + defer h.rwmutex.RUnlock() + return !h.open +} + func (h *hub) len() int { h.rwmutex.RLock() defer h.rwmutex.RUnlock() diff --git a/melody.go b/melody.go index 0ec990a..2af1624 100644 --- a/melody.go +++ b/melody.go @@ -1,6 +1,7 @@ package melody import ( + "errors" "github.com/gorilla/websocket" "net/http" "sync" @@ -79,12 +80,15 @@ func (m *Melody) HandleError(fn func(*Session, error)) { } // HandleRequest upgrades http requests to websocket connections and dispatches them to be handled by the melody instance. -func (m *Melody) HandleRequest(w http.ResponseWriter, r *http.Request) { +func (m *Melody) HandleRequest(w http.ResponseWriter, r *http.Request) error { + if m.hub.closed() { + return errors.New("Melody instance is closed.") + } + conn, err := m.Upgrader.Upgrade(w, r, nil) if err != nil { - m.errorHandler(nil, err) - return + return err } session := &Session{ @@ -93,7 +97,8 @@ func (m *Melody) HandleRequest(w http.ResponseWriter, r *http.Request) { conn: conn, output: make(chan *envelope, m.Config.MessageBufferSize), melody: m, - lock: &sync.Mutex{}, + open: true, + rwmutex: &sync.RWMutex{}, } m.hub.register <- session @@ -104,54 +109,88 @@ func (m *Melody) HandleRequest(w http.ResponseWriter, r *http.Request) { session.readPump() - if m.hub.open { + if !m.hub.closed() { m.hub.unregister <- session } - go m.disconnectHandler(session) + session.close() + + m.disconnectHandler(session) + + return nil } // Broadcast broadcasts a text message to all sessions. -func (m *Melody) Broadcast(msg []byte) { +func (m *Melody) Broadcast(msg []byte) error { + if m.hub.closed() { + return errors.New("Melody instance is closed.") + } + message := &envelope{t: websocket.TextMessage, msg: msg} m.hub.broadcast <- message + + return nil } // BroadcastFilter broadcasts a text message to all sessions that fn returns true for. -func (m *Melody) BroadcastFilter(msg []byte, fn func(*Session) bool) { +func (m *Melody) BroadcastFilter(msg []byte, fn func(*Session) bool) error { + if m.hub.closed() { + return errors.New("Melody instance is closed.") + } + message := &envelope{t: websocket.TextMessage, msg: msg, filter: fn} m.hub.broadcast <- message + + return nil } // BroadcastOthers broadcasts a text message to all sessions except session s. -func (m *Melody) BroadcastOthers(msg []byte, s *Session) { - m.BroadcastFilter(msg, func(q *Session) bool { +func (m *Melody) BroadcastOthers(msg []byte, s *Session) error { + return m.BroadcastFilter(msg, func(q *Session) bool { return s != q }) } // BroadcastBinary broadcasts a binary message to all sessions. -func (m *Melody) BroadcastBinary(msg []byte) { +func (m *Melody) BroadcastBinary(msg []byte) error { + if m.hub.closed() { + return errors.New("Melody instance is closed.") + } + message := &envelope{t: websocket.BinaryMessage, msg: msg} m.hub.broadcast <- message + + return nil } // BroadcastBinaryFilter broadcasts a binary message to all sessions that fn returns true for. -func (m *Melody) BroadcastBinaryFilter(msg []byte, fn func(*Session) bool) { +func (m *Melody) BroadcastBinaryFilter(msg []byte, fn func(*Session) bool) error { + if m.hub.closed() { + return errors.New("Melody instance is closed.") + } + message := &envelope{t: websocket.BinaryMessage, msg: msg, filter: fn} m.hub.broadcast <- message + + return nil } // BroadcastBinaryOthers broadcasts a binary message to all sessions except session s. -func (m *Melody) BroadcastBinaryOthers(msg []byte, s *Session) { - m.BroadcastBinaryFilter(msg, func(q *Session) bool { +func (m *Melody) BroadcastBinaryOthers(msg []byte, s *Session) error { + return m.BroadcastBinaryFilter(msg, func(q *Session) bool { return s != q }) } // Close closes the melody instance and all connected sessions. -func (m *Melody) Close() { +func (m *Melody) Close() error { + if m.hub.closed() { + return errors.New("Melody instance is already closed.") + } + m.hub.exit <- true + + return nil } // Len return the number of connected sessions. diff --git a/melody_test.go b/melody_test.go index fb31d2d..23a1bd7 100644 --- a/melody_test.go +++ b/melody_test.go @@ -80,6 +80,43 @@ func TestEcho(t *testing.T) { } } +func TestWriteClosed(t *testing.T) { + echo := NewTestServerHandler(func(session *Session, msg []byte) { + session.Write(msg) + }) + server := httptest.NewServer(echo) + defer server.Close() + + fn := func(msg string) bool { + conn, err := NewDialer(server.URL) + + if err != nil { + t.Error(err) + return false + } + + conn.WriteMessage(websocket.TextMessage, []byte(msg)) + + echo.m.HandleConnect(func(s *Session) { + s.Close() + }) + + echo.m.HandleDisconnect(func(s *Session) { + err := s.Write([]byte("hello world")) + + if err == nil { + t.Error("should be an error") + } + }) + + return true + } + + if err := quick.Check(fn, nil); err != nil { + t.Error(err) + } +} + func TestLen(t *testing.T) { rand.Seed(time.Now().UnixNano()) @@ -641,3 +678,47 @@ func TestPong(t *testing.T) { t.Error("should have fired pong handler") } } + +func BenchmarkSessionWrite(b *testing.B) { + echo := NewTestServerHandler(func(session *Session, msg []byte) { + session.Write(msg) + }) + server := httptest.NewServer(echo) + conn, _ := NewDialer(server.URL) + defer server.Close() + defer conn.Close() + + for n := 0; n < b.N; n++ { + conn.WriteMessage(websocket.TextMessage, []byte("test")) + conn.ReadMessage() + } +} + +func BenchmarkBroadcast(b *testing.B) { + echo := NewTestServerHandler(func(session *Session, msg []byte) { + session.Write(msg) + }) + server := httptest.NewServer(echo) + defer server.Close() + + conns := make([]*websocket.Conn, 0) + + num := 100 + + for i := 0; i < num; i++ { + conn, _ := NewDialer(server.URL) + conns = append(conns, conn) + } + + for n := 0; n < b.N; n++ { + echo.m.Broadcast([]byte("test")) + + for i := 0; i < num; i++ { + conns[i].ReadMessage() + } + } + + for i := 0; i < num; i++ { + conns[i].Close() + } +} diff --git a/session.go b/session.go index 1c02769..6b60857 100644 --- a/session.go +++ b/session.go @@ -8,25 +8,35 @@ import ( "time" ) -// Session is wrapper around websocket connections. +// Session wrapper around websocket connections. type Session struct { Request *http.Request Keys map[string]interface{} conn *websocket.Conn output chan *envelope melody *Melody - lock *sync.Mutex + open bool + rwmutex *sync.RWMutex } func (s *Session) writeMessage(message *envelope) { + if s.closed() { + s.melody.errorHandler(s, errors.New("Tried to write to closed a session.")) + return + } + select { case s.output <- message: default: - s.melody.errorHandler(s, errors.New("Message buffer full")) + s.melody.errorHandler(s, errors.New("Session message buffer is full.")) } } func (s *Session) writeRaw(message *envelope) error { + if s.closed() { + return errors.New("Trie to write to a closed session.") + } + s.conn.SetWriteDeadline(time.Now().Add(s.melody.Config.WriteWait)) err := s.conn.WriteMessage(message.t, message.msg) @@ -34,19 +44,24 @@ func (s *Session) writeRaw(message *envelope) error { return err } - if message.t == websocket.CloseMessage { - err := s.conn.Close() + return nil +} - if err != nil { - return err - } - } +func (s *Session) closed() bool { + s.rwmutex.RLock() + defer s.rwmutex.RUnlock() - return nil + return !s.open } func (s *Session) close() { - s.writeRaw(&envelope{t: websocket.CloseMessage, msg: []byte{}}) + if !s.closed() { + s.rwmutex.Lock() + s.open = false + s.conn.Close() + close(s.output) + s.rwmutex.Unlock() + } } func (s *Session) ping() { @@ -54,8 +69,6 @@ func (s *Session) ping() { } func (s *Session) writePump() { - defer s.conn.Close() - ticker := time.NewTicker(s.melody.Config.PingPeriod) defer ticker.Stop() @@ -64,13 +77,20 @@ loop: select { case msg, ok := <-s.output: if !ok { - s.close() break loop } - if err := s.writeRaw(msg); err != nil { + + err := s.writeRaw(msg) + + if err != nil { s.melody.errorHandler(s, err) break loop } + + if msg.t == websocket.CloseMessage { + break loop + } + case <-ticker.C: s.ping() } @@ -78,8 +98,6 @@ loop: } func (s *Session) readPump() { - defer s.conn.Close() - s.conn.SetReadLimit(s.melody.Config.MaxMessageSize) s.conn.SetReadDeadline(time.Now().Add(s.melody.Config.PongWait)) @@ -108,26 +126,41 @@ func (s *Session) readPump() { } // Write writes message to session. -func (s *Session) Write(msg []byte) { +func (s *Session) Write(msg []byte) error { + if s.closed() { + return errors.New("Session is closed.") + } + s.writeMessage(&envelope{t: websocket.TextMessage, msg: msg}) + + return nil } // WriteBinary writes a binary message to session. -func (s *Session) WriteBinary(msg []byte) { +func (s *Session) WriteBinary(msg []byte) error { + if s.closed() { + return errors.New("Session is closed.") + } + s.writeMessage(&envelope{t: websocket.BinaryMessage, msg: msg}) + + return nil } -// Close closes a session. -func (s *Session) Close() { +// Close closes session. +func (s *Session) Close() error { + if s.closed() { + return errors.New("Session is already closed.") + } + s.writeMessage(&envelope{t: websocket.CloseMessage, msg: []byte{}}) + + return nil } // Set is used to store a new key/value pair exclusivelly for this session. // It also lazy initializes s.Keys if it was not used previously. func (s *Session) Set(key string, value interface{}) { - s.lock.Lock() - defer s.lock.Unlock() - if s.Keys == nil { s.Keys = make(map[string]interface{}) } @@ -138,9 +171,6 @@ func (s *Session) Set(key string, value interface{}) { // Get returns the value for the given key, ie: (value, true). // If the value does not exists it returns (nil, false) func (s *Session) Get(key string) (value interface{}, exists bool) { - s.lock.Lock() - defer s.lock.Unlock() - if s.Keys != nil { value, exists = s.Keys[key] }