diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..31ea56b --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.DS_Store +benchmark +*.swp +coverage.out +Makefile diff --git a/.travis.yml b/.travis.yml index 11bbf38..0578b7a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,11 @@ language: go go: - - 1.4 +- 1.4 +before_install: + - go get github.com/axw/gocov/gocov + - go get github.com/mattn/goveralls + - if ! go get code.google.com/p/go.tools/cmd/cover; then go get golang.org/x/tools/cmd/cover; fi install: - go get github.com/gorilla/websocket -script: go test -v +script: + - $HOME/gopath/bin/goveralls -service=travis-ci diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..dbbaaff --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,4 @@ +### 2015-06-10 + +* Support for binary messages. +* BroadcastOthers method. diff --git a/examples/chat/main.go b/examples/chat/main.go index c3e8c10..8030926 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -1,8 +1,8 @@ package main import ( - "../../" "github.com/gin-gonic/gin" + "github.com/olahol/melody" "net/http" ) diff --git a/examples/filewatch/main.go b/examples/filewatch/main.go index 9f63113..334a304 100644 --- a/examples/filewatch/main.go +++ b/examples/filewatch/main.go @@ -1,9 +1,9 @@ package main import ( - "../../" "github.com/gin-gonic/gin" "github.com/go-fsnotify/fsnotify" + "github.com/olahol/melody" "io/ioutil" "net/http" ) diff --git a/examples/multichat/main.go b/examples/multichat/main.go index 20c4acd..13a97c4 100644 --- a/examples/multichat/main.go +++ b/examples/multichat/main.go @@ -1,8 +1,8 @@ package main import ( - "../../" "github.com/gin-gonic/gin" + "github.com/olahol/melody" "net/http" ) diff --git a/melody.go b/melody.go index 3d0ce8a..1026a75 100644 --- a/melody.go +++ b/melody.go @@ -11,13 +11,14 @@ type handleSessionFunc func(*Session) type filterFunc func(*Session) bool type Melody struct { - Config *Config - Upgrader *websocket.Upgrader - messageHandler handleMessageFunc - errorHandler handleErrorFunc - connectHandler handleSessionFunc - disconnectHandler handleSessionFunc - hub *hub + Config *Config + Upgrader *websocket.Upgrader + messageHandler handleMessageFunc + messageHandlerBinary handleMessageFunc + errorHandler handleErrorFunc + connectHandler handleSessionFunc + disconnectHandler handleSessionFunc + hub *hub } // Returns a new melody instance. @@ -32,13 +33,14 @@ func New() *Melody { go hub.run() return &Melody{ - Config: newConfig(), - Upgrader: upgrader, - messageHandler: func(*Session, []byte) {}, - errorHandler: func(*Session, error) {}, - connectHandler: func(*Session) {}, - disconnectHandler: func(*Session) {}, - hub: hub, + Config: newConfig(), + Upgrader: upgrader, + messageHandler: func(*Session, []byte) {}, + messageHandlerBinary: func(*Session, []byte) {}, + errorHandler: func(*Session, error) {}, + connectHandler: func(*Session) {}, + disconnectHandler: func(*Session) {}, + hub: hub, } } @@ -57,6 +59,11 @@ func (m *Melody) HandleMessage(fn func(*Session, []byte)) { m.messageHandler = fn } +// Callback when a binary message comes in. +func (m *Melody) HandleMessageBinary(fn func(*Session, []byte)) { + m.messageHandlerBinary = fn +} + // Fires when a session has an error. func (m *Melody) HandleError(fn func(*Session, error)) { m.errorHandler = fn @@ -79,7 +86,7 @@ func (m *Melody) HandleRequest(w http.ResponseWriter, r *http.Request) { go session.writePump(m.errorHandler) - session.readPump(m.messageHandler, m.errorHandler) + session.readPump(m.messageHandler, m.messageHandlerBinary, m.errorHandler) m.hub.unregister <- session @@ -97,3 +104,10 @@ func (m *Melody) BroadcastFilter(msg []byte, fn func(*Session) bool) { message := &envelope{t: websocket.TextMessage, msg: msg, filter: fn} m.hub.broadcast <- message } + +// Broadcasts a message to all sessions except session `s`. +func (m *Melody) BroadcastOthers(msg []byte, s *Session) { + m.BroadcastFilter(msg, func(q *Session) bool { + return s != q + }) +} diff --git a/melody_test.go b/melody_test.go index dcbfe97..bd409de 100644 --- a/melody_test.go +++ b/melody_test.go @@ -77,6 +77,69 @@ func TestEcho(t *testing.T) { } } +func TestEchoBinary(t *testing.T) { + echo := NewTestServer() + echo.m.HandleMessageBinary(func(session *Session, msg []byte) { + session.WriteBinary(msg) + }) + server := httptest.NewServer(echo) + defer server.Close() + + fn := func(msg string) bool { + conn, err := NewDialer(server.URL) + defer conn.Close() + + if err != nil { + t.Error(err) + return false + } + + conn.WriteMessage(websocket.BinaryMessage, []byte(msg)) + + _, ret, err := conn.ReadMessage() + + if err != nil { + t.Error(err) + return false + } + + if msg != string(ret) { + t.Errorf("%s should equal %s", msg, string(ret)) + return false + } + + return true + } + + if err := quick.Check(fn, nil); err != nil { + t.Error(err) + } +} + +func TestHandlers(t *testing.T) { + echo := NewTestServer() + echo.m.HandleMessage(func(session *Session, msg []byte) { + session.Write(msg) + }) + server := httptest.NewServer(echo) + defer server.Close() + + var q *Session + + echo.m.HandleConnect(func(session *Session) { + q = session + session.Close() + }) + + echo.m.HandleDisconnect(func(session *Session) { + if q != session { + t.Error("disconnecting session should be the same as connecting") + } + }) + + NewDialer(server.URL) +} + func TestUpgrader(t *testing.T) { broadcast := NewTestServer() broadcast.m.HandleMessage(func(session *Session, msg []byte) { @@ -141,6 +204,65 @@ func TestBroadcast(t *testing.T) { } } + _, ret, err := conn.ReadMessage() + + if err != nil { + t.Error(err) + return false + } + + if msg != string(ret) { + t.Errorf("%s should equal %s", msg, string(ret)) + return false + } + + return true + } + + if err := quick.Check(fn, nil); err != nil { + t.Error(err) + } +} + +func TestBroadcastOthers(t *testing.T) { + broadcast := NewTestServer() + broadcast.m.HandleMessage(func(session *Session, msg []byte) { + broadcast.m.BroadcastOthers(msg, session) + }) + broadcast.m.Config.PongWait = time.Second + broadcast.m.Config.PingPeriod = time.Second * 9 / 10 + server := httptest.NewServer(broadcast) + defer server.Close() + + n := 10 + + fn := func(msg string) bool { + conn, _ := NewDialer(server.URL) + defer conn.Close() + + listeners := make([]*websocket.Conn, n) + for i := 0; i < n; i++ { + listener, _ := NewDialer(server.URL) + listeners[i] = listener + defer listeners[i].Close() + } + + conn.WriteMessage(websocket.TextMessage, []byte(msg)) + + for i := 0; i < n; i++ { + _, ret, err := listeners[i].ReadMessage() + + if err != nil { + t.Error(err) + return false + } + + if msg != string(ret) { + t.Errorf("%s should equal %s", msg, string(ret)) + return false + } + } + return true } @@ -175,23 +297,18 @@ func TestPingPong(t *testing.T) { } } -/* func TestBroadcastFilter(t *testing.T) { - echo := NewTestServer() - echo.m.HandleMessage(func(session *Session, msg []byte) { - echo.m.BroadcastFilter(func(s *Session) bool { - //return s == session - return false - }, msg) + broadcast := NewTestServer() + broadcast.m.HandleMessage(func(session *Session, msg []byte) { + broadcast.m.BroadcastFilter(msg, func(q *Session) bool { + return session == q + }) }) - server := httptest.NewServer(echo) + server := httptest.NewServer(broadcast) defer server.Close() fn := func(msg string) bool { conn, err := NewDialer(server.URL) - conn.SetPingHandler(func(string) error { - return nil - }) defer conn.Close() if err != nil { @@ -220,47 +337,3 @@ func TestBroadcastFilter(t *testing.T) { t.Error(err) } } -*/ - -func BenchmarkEcho(b *testing.B) { - echo := NewTestServerHandler(func(session *Session, msg []byte) { - session.Write(msg) - }) - server := httptest.NewServer(echo) - defer server.Close() - - conn, _ := NewDialer(server.URL) - defer conn.Close() - - for i := 0; i < b.N; i++ { - conn.WriteMessage(websocket.TextMessage, []byte("test")) - conn.ReadMessage() - } -} - -func BenchmarkBroadcast(b *testing.B) { - broadcast := NewTestServer() - broadcast.m.HandleMessage(func(session *Session, msg []byte) { - broadcast.m.Broadcast(msg) - }) - server := httptest.NewServer(broadcast) - defer server.Close() - - conn, _ := NewDialer(server.URL) - defer conn.Close() - - n := 10 - listeners := make([]*websocket.Conn, n) - for i := 0; i < n; i++ { - listener, _ := NewDialer(server.URL) - listeners[i] = listener - defer listeners[i].Close() - } - - for i := 0; i < b.N; i++ { - conn.WriteMessage(websocket.TextMessage, []byte("test")) - for i := 0; i < n; i++ { - listeners[i].ReadMessage() - } - } -} diff --git a/session.go b/session.go index 3f4e471..0c890a6 100644 --- a/session.go +++ b/session.go @@ -29,7 +29,21 @@ func (s *Session) writeMessage(message *envelope) { func (s *Session) writeRaw(message *envelope) error { s.conn.SetWriteDeadline(time.Now().Add(s.config.WriteWait)) - return s.conn.WriteMessage(message.t, message.msg) + err := s.conn.WriteMessage(message.t, message.msg) + + if err != nil { + return err + } + + if message.t == websocket.CloseMessage { + err := s.conn.Close() + + if err != nil { + return err + } + } + + return nil } func (s *Session) close() { @@ -63,7 +77,7 @@ func (s *Session) writePump(errorHandler handleErrorFunc) { } } -func (s *Session) readPump(messageHandler handleMessageFunc, errorHandler handleErrorFunc) { +func (s *Session) readPump(messageHandler handleMessageFunc, messageHandlerBinary handleMessageFunc, errorHandler handleErrorFunc) { defer s.conn.Close() s.conn.SetReadLimit(s.config.MaxMessageSize) @@ -75,18 +89,29 @@ func (s *Session) readPump(messageHandler handleMessageFunc, errorHandler handle }) for { - _, message, err := s.conn.ReadMessage() + t, message, err := s.conn.ReadMessage() if err != nil { go errorHandler(s, err) break } - go messageHandler(s, message) + if t == websocket.TextMessage { + go messageHandler(s, message) + } + + if t == websocket.BinaryMessage { + go messageHandlerBinary(s, message) + } } } // Write message to session. +func (s *Session) WriteBinary(msg []byte) { + s.writeMessage(&envelope{t: websocket.BinaryMessage, msg: msg}) +} + +// Write message to session. func (s *Session) Write(msg []byte) { s.writeMessage(&envelope{t: websocket.TextMessage, msg: msg}) }