From ee4453a90447f924351f337b4b58b228f7e22735 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ola=20Holmstr=C3=B6m?= Date: Fri, 18 Sep 2015 14:50:46 +0200 Subject: [PATCH] ensure order of messages and dispatch error handler when buffer is full --- config.go | 2 +- hub.go | 4 ++-- melody.go | 11 ++++++++--- melody_test.go | 23 +++++++++++++++++++++++ session.go | 40 ++++++++++++++++++---------------------- 5 files changed, 52 insertions(+), 28 deletions(-) diff --git a/config.go b/config.go index 42ae530..3f7f880 100644 --- a/config.go +++ b/config.go @@ -8,7 +8,7 @@ type Config struct { PongWait time.Duration // Timeout for waiting on pong. PingPeriod time.Duration // Milliseconds between pings. MaxMessageSize int64 // Maximum size in bytes of a message. - MessageBufferSize int // Size of each sessions message buffer. + MessageBufferSize int // The max amount of messages that can be in a sessions buffer before it starts dropping them. } func newConfig() *Config { diff --git a/hub.go b/hub.go index 63ed440..097dbaa 100644 --- a/hub.go +++ b/hub.go @@ -36,10 +36,10 @@ loop: for s := range h.sessions { if m.filter != nil { if m.filter(s) { - go s.writeMessage(m) + s.writeMessage(m) } } else { - go s.writeMessage(m) + s.writeMessage(m) } } case <-h.exit: diff --git a/melody.go b/melody.go index 61ef4fd..1a72f24 100644 --- a/melody.go +++ b/melody.go @@ -78,15 +78,20 @@ func (m *Melody) HandleRequest(w http.ResponseWriter, r *http.Request) { return } - session := newSession(m.Config, conn, r) + session := &Session{ + Request: r, + conn: conn, + output: make(chan *envelope, m.Config.MessageBufferSize), + melody: m, + } m.hub.register <- session go m.connectHandler(session) - go session.writePump(m.errorHandler) + go session.writePump() - session.readPump(m.messageHandler, m.messageHandlerBinary, m.errorHandler) + session.readPump() if m.hub.open { m.hub.unregister <- session diff --git a/melody_test.go b/melody_test.go index e03f750..334bad4 100644 --- a/melody_test.go +++ b/melody_test.go @@ -340,3 +340,26 @@ func TestStop(t *testing.T) { noecho.m.Close() } + +func TestSmallMessageBuffer(t *testing.T) { + echo := NewTestServerHandler(func(session *Session, msg []byte) { + session.Write(msg) + }) + echo.m.Config.MessageBufferSize = 0 + echo.m.HandleError(func(s *Session, err error) { + if err == nil { + t.Error("there should be a buffer full error here") + } + }) + server := httptest.NewServer(echo) + defer server.Close() + + conn, err := NewDialer(server.URL) + defer conn.Close() + + if err != nil { + t.Error(err) + } + + conn.WriteMessage(websocket.TextMessage, []byte("12345")) +} diff --git a/session.go b/session.go index e84b603..e0f5996 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package melody import ( + "errors" "github.com/gorilla/websocket" "net/http" "time" @@ -11,24 +12,19 @@ type Session struct { Request *http.Request conn *websocket.Conn output chan *envelope - config *Config -} - -func newSession(config *Config, conn *websocket.Conn, req *http.Request) *Session { - return &Session{ - Request: req, - conn: conn, - output: make(chan *envelope, config.MessageBufferSize), - config: config, - } + melody *Melody } func (s *Session) writeMessage(message *envelope) { - s.output <- message + if len(s.output) < s.melody.Config.MessageBufferSize { + s.output <- message + } else { + s.melody.errorHandler(s, errors.New("Message buffer full")) + } } func (s *Session) writeRaw(message *envelope) error { - s.conn.SetWriteDeadline(time.Now().Add(s.config.WriteWait)) + s.conn.SetWriteDeadline(time.Now().Add(s.melody.Config.WriteWait)) err := s.conn.WriteMessage(message.t, message.msg) if err != nil { @@ -54,10 +50,10 @@ func (s *Session) ping() { s.writeMessage(&envelope{t: websocket.PingMessage, msg: []byte{}}) } -func (s *Session) writePump(errorHandler handleErrorFunc) { +func (s *Session) writePump() { defer s.conn.Close() - ticker := time.NewTicker(s.config.PingPeriod) + ticker := time.NewTicker(s.melody.Config.PingPeriod) defer ticker.Stop() loop: @@ -69,7 +65,7 @@ loop: break loop } if err := s.writeRaw(msg); err != nil { - go errorHandler(s, err) + s.melody.errorHandler(s, err) break loop } case <-ticker.C: @@ -78,14 +74,14 @@ loop: } } -func (s *Session) readPump(messageHandler handleMessageFunc, messageHandlerBinary handleMessageFunc, errorHandler handleErrorFunc) { +func (s *Session) readPump() { defer s.conn.Close() - s.conn.SetReadLimit(s.config.MaxMessageSize) - s.conn.SetReadDeadline(time.Now().Add(s.config.PongWait)) + s.conn.SetReadLimit(s.melody.Config.MaxMessageSize) + s.conn.SetReadDeadline(time.Now().Add(s.melody.Config.PongWait)) s.conn.SetPongHandler(func(string) error { - s.conn.SetReadDeadline(time.Now().Add(s.config.PongWait)) + s.conn.SetReadDeadline(time.Now().Add(s.melody.Config.PongWait)) return nil }) @@ -93,16 +89,16 @@ func (s *Session) readPump(messageHandler handleMessageFunc, messageHandlerBinar t, message, err := s.conn.ReadMessage() if err != nil { - go errorHandler(s, err) + s.melody.errorHandler(s, err) break } if t == websocket.TextMessage { - go messageHandler(s, message) + s.melody.messageHandler(s, message) } if t == websocket.BinaryMessage { - go messageHandlerBinary(s, message) + s.melody.messageHandlerBinary(s, message) } } }