websocket 增加多分组 fork https://github.com/olahol/melody
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

327 lines
6.2 KiB

package melody
import (
"github.com/gorilla/websocket"
"net/http"
"net/http/httptest"
"strings"
"testing"
"testing/quick"
"time"
)
type TestServer struct {
m *Melody
}
func NewTestServerHandler(handler handleMessageFunc) *TestServer {
m := New()
m.HandleMessage(handler)
return &TestServer{
m: m,
}
}
func NewTestServer() *TestServer {
m := New()
return &TestServer{
m: m,
}
}
func (s *TestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.m.HandleRequest(w, r)
}
func NewDialer(url string) (*websocket.Conn, error) {
dialer := &websocket.Dialer{}
conn, _, err := dialer.Dial(strings.Replace(url, "http", "ws", 1), nil)
return conn, err
}
func TestEcho(t *testing.T) {
echo := NewTestServerHandler(func(session *Session, msg []byte) {
session.Write(msg)
})
server := httptest.NewServer(echo)
defer server.Close()
fn := func(msg string) bool {
conn, err := NewDialer(server.URL)
defer conn.Close()
if err != nil {
t.Error(err)
return false
}
conn.WriteMessage(websocket.TextMessage, []byte(msg))
_, ret, err := conn.ReadMessage()
if err != nil {
t.Error(err)
return false
}
if msg != string(ret) {
t.Errorf("%s should equal %s", msg, string(ret))
return false
}
return true
}
if err := quick.Check(fn, nil); err != nil {
t.Error(err)
}
}
func TestEchoBinary(t *testing.T) {
echo := NewTestServer()
echo.m.HandleMessageBinary(func(session *Session, msg []byte) {
session.WriteBinary(msg)
})
server := httptest.NewServer(echo)
defer server.Close()
fn := func(msg string) bool {
conn, err := NewDialer(server.URL)
defer conn.Close()
if err != nil {
t.Error(err)
return false
}
conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
_, ret, err := conn.ReadMessage()
if err != nil {
t.Error(err)
return false
}
if msg != string(ret) {
t.Errorf("%s should equal %s", msg, string(ret))
return false
}
return true
}
if err := quick.Check(fn, nil); err != nil {
t.Error(err)
}
}
func TestHandlers(t *testing.T) {
echo := NewTestServer()
echo.m.HandleMessage(func(session *Session, msg []byte) {
session.Write(msg)
})
server := httptest.NewServer(echo)
defer server.Close()
var q *Session
echo.m.HandleConnect(func(session *Session) {
q = session
session.Close()
})
echo.m.HandleDisconnect(func(session *Session) {
if q != session {
t.Error("disconnecting session should be the same as connecting")
}
})
NewDialer(server.URL)
}
func TestUpgrader(t *testing.T) {
broadcast := NewTestServer()
broadcast.m.HandleMessage(func(session *Session, msg []byte) {
session.Write(msg)
})
server := httptest.NewServer(broadcast)
defer server.Close()
broadcast.m.Upgrader = &websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { return false },
}
broadcast.m.HandleError(func(session *Session, err error) {
if err == nil || err.Error() != "websocket: origin not allowed" {
t.Error("there should be a origin error")
}
})
_, err := NewDialer(server.URL)
if err == nil || err.Error() != "websocket: bad handshake" {
t.Error("there should be a badhandshake error")
}
}
func TestBroadcast(t *testing.T) {
broadcast := NewTestServer()
broadcast.m.HandleMessage(func(session *Session, msg []byte) {
broadcast.m.Broadcast(msg)
})
server := httptest.NewServer(broadcast)
defer server.Close()
n := 10
fn := func(msg string) bool {
conn, _ := NewDialer(server.URL)
defer conn.Close()
listeners := make([]*websocket.Conn, n)
for i := 0; i < n; i++ {
listener, _ := NewDialer(server.URL)
listeners[i] = listener
defer listeners[i].Close()
}
conn.WriteMessage(websocket.TextMessage, []byte(msg))
for i := 0; i < n; i++ {
_, ret, err := listeners[i].ReadMessage()
if err != nil {
t.Error(err)
return false
}
if msg != string(ret) {
t.Errorf("%s should equal %s", msg, string(ret))
return false
}
}
return true
}
if err := quick.Check(fn, nil); err != nil {
t.Error(err)
}
}
func TestBroadcastOthers(t *testing.T) {
broadcast := NewTestServer()
broadcast.m.HandleMessage(func(session *Session, msg []byte) {
broadcast.m.BroadcastOthers(msg, session)
})
broadcast.m.Config.PongWait = time.Second
broadcast.m.Config.PingPeriod = time.Second * 9 / 10
server := httptest.NewServer(broadcast)
defer server.Close()
n := 10
fn := func(msg string) bool {
conn, _ := NewDialer(server.URL)
defer conn.Close()
listeners := make([]*websocket.Conn, n)
for i := 0; i < n; i++ {
listener, _ := NewDialer(server.URL)
listeners[i] = listener
defer listeners[i].Close()
}
conn.WriteMessage(websocket.TextMessage, []byte(msg))
for i := 0; i < n; i++ {
_, ret, err := listeners[i].ReadMessage()
if err != nil {
t.Error(err)
return false
}
if msg != string(ret) {
t.Errorf("%s should equal %s", msg, string(ret))
return false
}
}
return true
}
if err := quick.Check(fn, nil); err != nil {
t.Error(err)
}
}
func TestPingPong(t *testing.T) {
noecho := NewTestServer()
noecho.m.Config.PongWait = time.Second
noecho.m.Config.PingPeriod = time.Second * 9 / 10
server := httptest.NewServer(noecho)
defer server.Close()
conn, err := NewDialer(server.URL)
conn.SetPingHandler(func(string) error {
return nil
})
defer conn.Close()
if err != nil {
t.Error(err)
}
conn.WriteMessage(websocket.TextMessage, []byte("test"))
_, _, err = conn.ReadMessage()
if err == nil {
t.Error("there should be an error")
}
}
func TestBroadcastFilter(t *testing.T) {
broadcast := NewTestServer()
broadcast.m.HandleMessage(func(session *Session, msg []byte) {
broadcast.m.BroadcastFilter(msg, func(q *Session) bool {
return session == q
})
})
server := httptest.NewServer(broadcast)
defer server.Close()
fn := func(msg string) bool {
conn, err := NewDialer(server.URL)
defer conn.Close()
if err != nil {
t.Error(err)
return false
}
conn.WriteMessage(websocket.TextMessage, []byte(msg))
_, ret, err := conn.ReadMessage()
if err != nil {
t.Error(err)
return false
}
if msg != string(ret) {
t.Errorf("%s should equal %s", msg, string(ret))
return false
}
return true
}
if err := quick.Check(fn, nil); err != nil {
t.Error(err)
}
}