commit fb9197b6ff29e5948bd758ba0a4d426aa13e4c74 Author: Ola Holmström Date: Wed May 13 22:37:17 2015 +0200 first version diff --git a/README.md b/README.md new file mode 100644 index 0000000..3154312 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# melody + +> Simple websocket framework for Go diff --git a/config.go b/config.go new file mode 100644 index 0000000..84ae834 --- /dev/null +++ b/config.go @@ -0,0 +1,21 @@ +package melody + +import "time" + +type Config struct { + WriteWait time.Duration + PongWait time.Duration + PingPeriod time.Duration + MaxMessageSize int64 + MessageBufferSize int +} + +func newConfig() *Config { + return &Config{ + WriteWait: 10 * time.Second, + PongWait: 60 * time.Second, + PingPeriod: (60 * time.Second * 9) / 10, + MaxMessageSize: 512, + MessageBufferSize: 256, + } +} diff --git a/envelope.go b/envelope.go new file mode 100644 index 0000000..baa55a3 --- /dev/null +++ b/envelope.go @@ -0,0 +1,7 @@ +package melody + +type envelope struct { + t int + msg []byte + filter filterFunc +} diff --git a/examples/chat/index.html b/examples/chat/index.html new file mode 100644 index 0000000..eac2a95 --- /dev/null +++ b/examples/chat/index.html @@ -0,0 +1,50 @@ + + + Melody example: chatting + + + + + +
+

Chat

+

+      
+    
+ + + + diff --git a/examples/chat/main.go b/examples/chat/main.go new file mode 100644 index 0000000..b136ab9 --- /dev/null +++ b/examples/chat/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "../../" + "github.com/gin-gonic/gin" + "net/http" +) + +func main() { + r := gin.Default() + m := melody.Default() + + r.GET("/", func(c *gin.Context) { + http.ServeFile(c.Writer, c.Request, "index.html") + }) + + r.GET("/ws", func(c *gin.Context) { + m.HandleRequest(c.Writer, c.Request) + }) + + m.HandleMessage(func(s *melody.Session, msg []byte) { + m.Broadcast(msg) + }) + + r.Run(":5000") +} diff --git a/examples/filewatch/file.txt b/examples/filewatch/file.txt new file mode 100644 index 0000000..0b70aef --- /dev/null +++ b/examples/filewatch/file.txt @@ -0,0 +1 @@ +Hello World! World Earl diff --git a/examples/filewatch/index.html b/examples/filewatch/index.html new file mode 100644 index 0000000..6276a4c --- /dev/null +++ b/examples/filewatch/index.html @@ -0,0 +1,32 @@ + + + Melody example: file watching + + + + + +
+

Watching a file

+

+    
+ + + + diff --git a/examples/filewatch/main.go b/examples/filewatch/main.go new file mode 100644 index 0000000..f034293 --- /dev/null +++ b/examples/filewatch/main.go @@ -0,0 +1,44 @@ +package main + +import ( + "../../" + "github.com/gin-gonic/gin" + "github.com/go-fsnotify/fsnotify" + "io/ioutil" + "net/http" +) + +func main() { + file := "file.txt" + + r := gin.Default() + m := melody.Default() + w, _ := fsnotify.NewWatcher() + + r.GET("/", func(c *gin.Context) { + http.ServeFile(c.Writer, c.Request, "index.html") + }) + + r.GET("/ws", func(c *gin.Context) { + m.HandleRequest(c.Writer, c.Request) + }) + + m.HandleConnect(func(s *melody.Session) { + content, _ := ioutil.ReadFile(file) + s.Write(content) + }) + + go func() { + for { + ev := <-w.Events + if ev.Op == fsnotify.Write { + content, _ := ioutil.ReadFile(ev.Name) + m.Broadcast(content) + } + } + }() + + w.Add(file) + + r.Run(":5000") +} diff --git a/hub.go b/hub.go new file mode 100644 index 0000000..372ca8f --- /dev/null +++ b/hub.go @@ -0,0 +1,42 @@ +package melody + +type hub struct { + sessions map[*Session]bool + broadcast chan *envelope + register chan *Session + unregister chan *Session +} + +func newHub() *hub { + return &hub{ + sessions: make(map[*Session]bool), + broadcast: make(chan *envelope), + register: make(chan *Session), + unregister: make(chan *Session), + } +} + +func (h *hub) run() { + for { + select { + case s := <-h.register: + h.sessions[s] = true + case s := <-h.unregister: + if _, ok := h.sessions[s]; ok { + delete(h.sessions, s) + close(s.output) + s.Conn.Close() + } + case m := <-h.broadcast: + for s := range h.sessions { + if m.filter != nil { + if m.filter(s) { + s.writeMessage(m) + } + } else { + s.writeMessage(m) + } + } + } + } +} diff --git a/melody.go b/melody.go new file mode 100644 index 0000000..0682d55 --- /dev/null +++ b/melody.go @@ -0,0 +1,92 @@ +package melody + +import ( + "github.com/gorilla/websocket" + "net/http" +) + +type handleMessageFunc func(*Session, []byte) +type handleErrorFunc func(*Session, error) +type handleSessionFunc func(*Session) +type filterFunc func(*Session) bool + +type Melody struct { + Config *Config + Upgrader *websocket.Upgrader + MessageHandler handleMessageFunc + ErrorHandler handleErrorFunc + ConnectHandler handleSessionFunc + DisconnectHandler handleSessionFunc + hub *hub +} + +func Default() *Melody { + upgrader := &websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + hub := newHub() + + go hub.run() + + return &Melody{ + Config: newConfig(), + Upgrader: upgrader, + MessageHandler: func(*Session, []byte) {}, + ErrorHandler: func(*Session, error) {}, + ConnectHandler: func(*Session) {}, + DisconnectHandler: func(*Session) {}, + hub: hub, + } +} + +func (m *Melody) HandleConnect(fn handleSessionFunc) { + m.ConnectHandler = fn +} + +func (m *Melody) HandleDisconnect(fn handleSessionFunc) { + m.DisconnectHandler = fn +} + +func (m *Melody) HandleMessage(fn handleMessageFunc) { + m.MessageHandler = fn +} + +func (m *Melody) HandleError(fn handleErrorFunc) { + m.ErrorHandler = fn +} + +func (m *Melody) HandleRequest(w http.ResponseWriter, r *http.Request) error { + conn, err := m.Upgrader.Upgrade(w, r, nil) + + if err != nil { + return err + } + + session := newSession(m.Config, conn) + + m.hub.register <- session + + go m.ConnectHandler(session) + + go session.writePump(m.ErrorHandler) + + session.readPump(m.MessageHandler, m.ErrorHandler) + + m.hub.unregister <- session + + go m.DisconnectHandler(session) + + return nil +} + +func (m *Melody) Broadcast(msg []byte) { + message := &envelope{t: websocket.TextMessage, msg: msg} + m.hub.broadcast <- message +} + +func (m *Melody) BroadcastFilter(fn filterFunc, msg []byte) { + message := &envelope{t: websocket.TextMessage, msg: msg, filter: fn} + m.hub.broadcast <- message +} diff --git a/melody_test.go b/melody_test.go new file mode 100644 index 0000000..b692b46 --- /dev/null +++ b/melody_test.go @@ -0,0 +1,239 @@ +package melody + +import ( + "github.com/gorilla/websocket" + "net/http" + "net/http/httptest" + "strings" + "testing" + "testing/quick" + "time" +) + +type TestServer struct { + m *Melody +} + +func NewTestServerHandler(handler handleMessageFunc) *TestServer { + m := Default() + m.HandleMessage(handler) + return &TestServer{ + m: m, + } +} + +func NewTestServer() *TestServer { + m := Default() + return &TestServer{ + m: m, + } +} + +func (s *TestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.m.HandleRequest(w, r) +} + +func NewDialer(url string) (*websocket.Conn, error) { + dialer := &websocket.Dialer{} + conn, _, err := dialer.Dial(strings.Replace(url, "http", "ws", 1), nil) + return conn, err +} + +func TestEcho(t *testing.T) { + echo := NewTestServerHandler(func(session *Session, msg []byte) { + session.Write(msg) + }) + server := httptest.NewServer(echo) + defer server.Close() + + fn := func(msg string) bool { + conn, err := NewDialer(server.URL) + defer conn.Close() + + if err != nil { + t.Error(err) + return false + } + + conn.WriteMessage(websocket.TextMessage, []byte(msg)) + + _, ret, err := conn.ReadMessage() + + if err != nil { + t.Error(err) + return false + } + + if msg != string(ret) { + t.Errorf("%s should equal %s", msg, string(ret)) + return false + } + + return true + } + + if err := quick.Check(fn, nil); err != nil { + t.Error(err) + } +} + +func TestBroadcast(t *testing.T) { + broadcast := NewTestServer() + broadcast.m.HandleMessage(func(session *Session, msg []byte) { + broadcast.m.Broadcast(msg) + }) + server := httptest.NewServer(broadcast) + defer server.Close() + + n := 10 + + fn := func(msg string) bool { + conn, _ := NewDialer(server.URL) + defer conn.Close() + + listeners := make([]*websocket.Conn, n) + for i := 0; i < n; i++ { + listener, _ := NewDialer(server.URL) + listeners[i] = listener + defer listeners[i].Close() + } + + conn.WriteMessage(websocket.TextMessage, []byte(msg)) + + for i := 0; i < n; i++ { + _, ret, err := listeners[i].ReadMessage() + + if err != nil { + t.Error(err) + return false + } + + if msg != string(ret) { + t.Errorf("%s should equal %s", msg, string(ret)) + return false + } + } + + return true + } + + if err := quick.Check(fn, nil); err != nil { + t.Error(err) + } +} + +func TestPingPong(t *testing.T) { + noecho := NewTestServer() + noecho.m.Config.PongWait = time.Second + noecho.m.Config.PingPeriod = time.Second * 9 / 10 + server := httptest.NewServer(noecho) + defer server.Close() + + conn, err := NewDialer(server.URL) + conn.SetPingHandler(func(string) error { + return nil + }) + defer conn.Close() + + if err != nil { + t.Error(err) + } + + conn.WriteMessage(websocket.TextMessage, []byte("test")) + + _, _, err = conn.ReadMessage() + + if err == nil { + t.Error("there should be an error") + } +} + +/* +func TestBroadcastFilter(t *testing.T) { + echo := NewTestServer() + echo.m.HandleMessage(func(session *Session, msg []byte) { + echo.m.BroadcastFilter(func(s *Session) bool { + //return s == session + return false + }, msg) + }) + server := httptest.NewServer(echo) + defer server.Close() + + fn := func(msg string) bool { + conn, err := NewDialer(server.URL) + conn.SetPingHandler(func(string) error { + return nil + }) + defer conn.Close() + + if err != nil { + t.Error(err) + return false + } + + conn.WriteMessage(websocket.TextMessage, []byte(msg)) + + _, ret, err := conn.ReadMessage() + + if err != nil { + t.Error(err) + return false + } + + if msg != string(ret) { + t.Errorf("%s should equal %s", msg, string(ret)) + return false + } + + return true + } + + if err := quick.Check(fn, nil); err != nil { + t.Error(err) + } +} +*/ + +func BenchmarkEcho(b *testing.B) { + echo := NewTestServerHandler(func(session *Session, msg []byte) { + session.Write(msg) + }) + server := httptest.NewServer(echo) + defer server.Close() + + conn, _ := NewDialer(server.URL) + defer conn.Close() + + for i := 0; i < b.N; i++ { + conn.WriteMessage(websocket.TextMessage, []byte("test")) + conn.ReadMessage() + } +} + +func BenchmarkBroadcast(b *testing.B) { + broadcast := NewTestServer() + broadcast.m.HandleMessage(func(session *Session, msg []byte) { + broadcast.m.Broadcast(msg) + }) + server := httptest.NewServer(broadcast) + defer server.Close() + + conn, _ := NewDialer(server.URL) + defer conn.Close() + + n := 10 + listeners := make([]*websocket.Conn, n) + for i := 0; i < n; i++ { + listener, _ := NewDialer(server.URL) + listeners[i] = listener + defer listeners[i].Close() + } + + for i := 0; i < b.N; i++ { + conn.WriteMessage(websocket.TextMessage, []byte("test")) + for i := 0; i < n; i++ { + listeners[i].ReadMessage() + } + } +} diff --git a/session.go b/session.go new file mode 100644 index 0000000..b12847b --- /dev/null +++ b/session.go @@ -0,0 +1,91 @@ +package melody + +import ( + "github.com/gorilla/websocket" + "time" +) + +type Session struct { + Conn *websocket.Conn + output chan *envelope + config *Config +} + +func newSession(config *Config, conn *websocket.Conn) *Session { + return &Session{ + Conn: conn, + output: make(chan *envelope, config.MessageBufferSize), + config: config, + } +} + +func (s *Session) writeMessage(message *envelope) { + s.output <- message +} + +func (s *Session) writeRaw(message *envelope) error { + s.Conn.SetWriteDeadline(time.Now().Add(s.config.WriteWait)) + return s.Conn.WriteMessage(message.t, message.msg) +} + +func (s *Session) close() { + s.writeRaw(&envelope{t: websocket.CloseMessage, msg: []byte{}}) +} + +func (s *Session) ping() { + s.writeMessage(&envelope{t: websocket.PingMessage, msg: []byte{}}) +} + +func (s *Session) writePump(errorHandler handleErrorFunc) { + defer s.Conn.Close() + + ticker := time.NewTicker(s.config.PingPeriod) + defer ticker.Stop() + + for { + select { + case msg, ok := <-s.output: + if !ok { + s.close() + return + } + if err := s.writeRaw(msg); err != nil { + go errorHandler(s, err) + return + } + case <-ticker.C: + s.ping() + } + } +} + +func (s *Session) readPump(messageHandler handleMessageFunc, errorHandler handleErrorFunc) { + defer s.Conn.Close() + + s.Conn.SetReadLimit(s.config.MaxMessageSize) + s.Conn.SetReadDeadline(time.Now().Add(s.config.PongWait)) + + s.Conn.SetPongHandler(func(string) error { + s.Conn.SetReadDeadline(time.Now().Add(s.config.PongWait)) + return nil + }) + + for { + _, message, err := s.Conn.ReadMessage() + + if err != nil { + go errorHandler(s, err) + break + } + + go messageHandler(s, message) + } +} + +func (s *Session) Write(msg []byte) { + s.writeMessage(&envelope{t: websocket.TextMessage, msg: msg}) +} + +func (s *Session) Close() { + s.writeMessage(&envelope{t: websocket.CloseMessage, msg: []byte{}}) +}