websocket: adds multi-group support (fork of https://github.com/olahol/melody)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

597 lines
12 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
9 years ago
  1. package melody
  2. import (
  3. "bytes"
  4. "github.com/gorilla/websocket"
  5. "net/http"
  6. "net/http/httptest"
  7. "strconv"
  8. "strings"
  9. "testing"
  10. "testing/quick"
  11. "time"
  12. )
  13. type TestServer struct {
  14. m *Melody
  15. }
  16. func NewTestServerHandler(handler handleMessageFunc) *TestServer {
  17. m := New()
  18. m.HandleMessage(handler)
  19. return &TestServer{
  20. m: m,
  21. }
  22. }
  23. func NewTestServer() *TestServer {
  24. m := New()
  25. return &TestServer{
  26. m: m,
  27. }
  28. }
  29. func (s *TestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  30. s.m.HandleRequest(w, r)
  31. }
  32. func NewDialer(url string) (*websocket.Conn, error) {
  33. dialer := &websocket.Dialer{}
  34. conn, _, err := dialer.Dial(strings.Replace(url, "http", "ws", 1), nil)
  35. return conn, err
  36. }
  37. func TestEcho(t *testing.T) {
  38. echo := NewTestServerHandler(func(session *Session, msg []byte) {
  39. session.Write(msg)
  40. })
  41. server := httptest.NewServer(echo)
  42. defer server.Close()
  43. fn := func(msg string) bool {
  44. conn, err := NewDialer(server.URL)
  45. defer conn.Close()
  46. if err != nil {
  47. t.Error(err)
  48. return false
  49. }
  50. conn.WriteMessage(websocket.TextMessage, []byte(msg))
  51. _, ret, err := conn.ReadMessage()
  52. if err != nil {
  53. t.Error(err)
  54. return false
  55. }
  56. if msg != string(ret) {
  57. t.Errorf("%s should equal %s", msg, string(ret))
  58. return false
  59. }
  60. return true
  61. }
  62. if err := quick.Check(fn, nil); err != nil {
  63. t.Error(err)
  64. }
  65. }
  66. func TestEchoBinary(t *testing.T) {
  67. echo := NewTestServer()
  68. echo.m.HandleMessageBinary(func(session *Session, msg []byte) {
  69. session.WriteBinary(msg)
  70. })
  71. server := httptest.NewServer(echo)
  72. defer server.Close()
  73. fn := func(msg string) bool {
  74. conn, err := NewDialer(server.URL)
  75. defer conn.Close()
  76. if err != nil {
  77. t.Error(err)
  78. return false
  79. }
  80. conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
  81. _, ret, err := conn.ReadMessage()
  82. if err != nil {
  83. t.Error(err)
  84. return false
  85. }
  86. if msg != string(ret) {
  87. t.Errorf("%s should equal %s", msg, string(ret))
  88. return false
  89. }
  90. return true
  91. }
  92. if err := quick.Check(fn, nil); err != nil {
  93. t.Error(err)
  94. }
  95. }
  96. func TestHandlers(t *testing.T) {
  97. echo := NewTestServer()
  98. echo.m.HandleMessage(func(session *Session, msg []byte) {
  99. session.Write(msg)
  100. })
  101. server := httptest.NewServer(echo)
  102. defer server.Close()
  103. var q *Session
  104. echo.m.HandleConnect(func(session *Session) {
  105. q = session
  106. session.Close()
  107. })
  108. echo.m.HandleDisconnect(func(session *Session) {
  109. if q != session {
  110. t.Error("disconnecting session should be the same as connecting")
  111. }
  112. })
  113. NewDialer(server.URL)
  114. }
  115. func TestMetadata(t *testing.T) {
  116. echo := NewTestServer()
  117. echo.m.HandleConnect(func(session *Session) {
  118. session.Set("stamp", time.Now().UnixNano())
  119. })
  120. echo.m.HandleMessage(func(session *Session, msg []byte) {
  121. stamp := session.MustGet("stamp").(int64)
  122. session.Write([]byte(strconv.Itoa(int(stamp))))
  123. })
  124. server := httptest.NewServer(echo)
  125. defer server.Close()
  126. fn := func(msg string) bool {
  127. conn, err := NewDialer(server.URL)
  128. defer conn.Close()
  129. if err != nil {
  130. t.Error(err)
  131. return false
  132. }
  133. conn.WriteMessage(websocket.TextMessage, []byte(msg))
  134. _, ret, err := conn.ReadMessage()
  135. if err != nil {
  136. t.Error(err)
  137. return false
  138. }
  139. stamp, err := strconv.Atoi(string(ret))
  140. if err != nil {
  141. t.Error(err)
  142. return false
  143. }
  144. diff := int(time.Now().UnixNano()) - stamp
  145. if diff <= 0 {
  146. t.Errorf("diff should be above 0 %d", diff)
  147. return false
  148. }
  149. return true
  150. }
  151. if err := quick.Check(fn, nil); err != nil {
  152. t.Error(err)
  153. }
  154. }
  155. func TestUpgrader(t *testing.T) {
  156. broadcast := NewTestServer()
  157. broadcast.m.HandleMessage(func(session *Session, msg []byte) {
  158. session.Write(msg)
  159. })
  160. server := httptest.NewServer(broadcast)
  161. defer server.Close()
  162. broadcast.m.Upgrader = &websocket.Upgrader{
  163. ReadBufferSize: 1024,
  164. WriteBufferSize: 1024,
  165. CheckOrigin: func(r *http.Request) bool { return false },
  166. }
  167. broadcast.m.HandleError(func(session *Session, err error) {
  168. if err == nil || err.Error() != "websocket: origin not allowed" {
  169. t.Error("there should be a origin error")
  170. }
  171. })
  172. _, err := NewDialer(server.URL)
  173. if err == nil || err.Error() != "websocket: bad handshake" {
  174. t.Error("there should be a badhandshake error")
  175. }
  176. }
  177. func TestBroadcast(t *testing.T) {
  178. broadcast := NewTestServer()
  179. broadcast.m.HandleMessage(func(session *Session, msg []byte) {
  180. broadcast.m.Broadcast(msg)
  181. })
  182. server := httptest.NewServer(broadcast)
  183. defer server.Close()
  184. n := 10
  185. fn := func(msg string) bool {
  186. conn, _ := NewDialer(server.URL)
  187. defer conn.Close()
  188. listeners := make([]*websocket.Conn, n)
  189. for i := 0; i < n; i++ {
  190. listener, _ := NewDialer(server.URL)
  191. listeners[i] = listener
  192. defer listeners[i].Close()
  193. }
  194. conn.WriteMessage(websocket.TextMessage, []byte(msg))
  195. for i := 0; i < n; i++ {
  196. _, ret, err := listeners[i].ReadMessage()
  197. if err != nil {
  198. t.Error(err)
  199. return false
  200. }
  201. if msg != string(ret) {
  202. t.Errorf("%s should equal %s", msg, string(ret))
  203. return false
  204. }
  205. }
  206. return true
  207. }
  208. if !fn("test") {
  209. t.Errorf("should not be false")
  210. }
  211. }
  212. func TestBroadcastBinary(t *testing.T) {
  213. broadcast := NewTestServer()
  214. broadcast.m.HandleMessageBinary(func(session *Session, msg []byte) {
  215. broadcast.m.BroadcastBinary(msg)
  216. })
  217. server := httptest.NewServer(broadcast)
  218. defer server.Close()
  219. n := 10
  220. fn := func(msg []byte) bool {
  221. conn, _ := NewDialer(server.URL)
  222. defer conn.Close()
  223. listeners := make([]*websocket.Conn, n)
  224. for i := 0; i < n; i++ {
  225. listener, _ := NewDialer(server.URL)
  226. listeners[i] = listener
  227. defer listeners[i].Close()
  228. }
  229. conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
  230. for i := 0; i < n; i++ {
  231. messageType, ret, err := listeners[i].ReadMessage()
  232. if err != nil {
  233. t.Error(err)
  234. return false
  235. }
  236. if messageType != websocket.BinaryMessage {
  237. t.Errorf("message type should be BinaryMessage")
  238. return false
  239. }
  240. if !bytes.Equal(msg, ret) {
  241. t.Errorf("%v should equal %v", msg, ret)
  242. return false
  243. }
  244. }
  245. return true
  246. }
  247. if !fn([]byte{2, 3, 5, 7, 11}) {
  248. t.Errorf("should not be false")
  249. }
  250. }
  251. func TestBroadcastOthers(t *testing.T) {
  252. broadcast := NewTestServer()
  253. broadcast.m.HandleMessage(func(session *Session, msg []byte) {
  254. broadcast.m.BroadcastOthers(msg, session)
  255. })
  256. broadcast.m.Config.PongWait = time.Second
  257. broadcast.m.Config.PingPeriod = time.Second * 9 / 10
  258. server := httptest.NewServer(broadcast)
  259. defer server.Close()
  260. n := 10
  261. fn := func(msg string) bool {
  262. conn, _ := NewDialer(server.URL)
  263. defer conn.Close()
  264. listeners := make([]*websocket.Conn, n)
  265. for i := 0; i < n; i++ {
  266. listener, _ := NewDialer(server.URL)
  267. listeners[i] = listener
  268. defer listeners[i].Close()
  269. }
  270. conn.WriteMessage(websocket.TextMessage, []byte(msg))
  271. for i := 0; i < n; i++ {
  272. _, ret, err := listeners[i].ReadMessage()
  273. if err != nil {
  274. t.Error(err)
  275. return false
  276. }
  277. if msg != string(ret) {
  278. t.Errorf("%s should equal %s", msg, string(ret))
  279. return false
  280. }
  281. }
  282. return true
  283. }
  284. if !fn("test") {
  285. t.Errorf("should not be false")
  286. }
  287. }
  288. func TestBroadcastBinaryOthers(t *testing.T) {
  289. broadcast := NewTestServer()
  290. broadcast.m.HandleMessageBinary(func(session *Session, msg []byte) {
  291. broadcast.m.BroadcastBinaryOthers(msg, session)
  292. })
  293. broadcast.m.Config.PongWait = time.Second
  294. broadcast.m.Config.PingPeriod = time.Second * 9 / 10
  295. server := httptest.NewServer(broadcast)
  296. defer server.Close()
  297. n := 10
  298. fn := func(msg []byte) bool {
  299. conn, _ := NewDialer(server.URL)
  300. defer conn.Close()
  301. listeners := make([]*websocket.Conn, n)
  302. for i := 0; i < n; i++ {
  303. listener, _ := NewDialer(server.URL)
  304. listeners[i] = listener
  305. defer listeners[i].Close()
  306. }
  307. conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
  308. for i := 0; i < n; i++ {
  309. messageType, ret, err := listeners[i].ReadMessage()
  310. if err != nil {
  311. t.Error(err)
  312. return false
  313. }
  314. if messageType != websocket.BinaryMessage {
  315. t.Errorf("message type should be BinaryMessage")
  316. return false
  317. }
  318. if !bytes.Equal(msg, ret) {
  319. t.Errorf("%v should equal %v", msg, ret)
  320. return false
  321. }
  322. }
  323. return true
  324. }
  325. if !fn([]byte{2, 3, 5, 7, 11}) {
  326. t.Errorf("should not be false")
  327. }
  328. }
  329. func TestPingPong(t *testing.T) {
  330. noecho := NewTestServer()
  331. noecho.m.Config.PongWait = time.Second
  332. noecho.m.Config.PingPeriod = time.Second * 9 / 10
  333. server := httptest.NewServer(noecho)
  334. defer server.Close()
  335. conn, err := NewDialer(server.URL)
  336. conn.SetPingHandler(func(string) error {
  337. return nil
  338. })
  339. defer conn.Close()
  340. if err != nil {
  341. t.Error(err)
  342. }
  343. conn.WriteMessage(websocket.TextMessage, []byte("test"))
  344. _, _, err = conn.ReadMessage()
  345. if err == nil {
  346. t.Error("there should be an error")
  347. }
  348. }
  349. func TestBroadcastFilter(t *testing.T) {
  350. broadcast := NewTestServer()
  351. broadcast.m.HandleMessage(func(session *Session, msg []byte) {
  352. broadcast.m.BroadcastFilter(msg, func(q *Session) bool {
  353. return session == q
  354. })
  355. })
  356. server := httptest.NewServer(broadcast)
  357. defer server.Close()
  358. fn := func(msg string) bool {
  359. conn, err := NewDialer(server.URL)
  360. defer conn.Close()
  361. if err != nil {
  362. t.Error(err)
  363. return false
  364. }
  365. conn.WriteMessage(websocket.TextMessage, []byte(msg))
  366. _, ret, err := conn.ReadMessage()
  367. if err != nil {
  368. t.Error(err)
  369. return false
  370. }
  371. if msg != string(ret) {
  372. t.Errorf("%s should equal %s", msg, string(ret))
  373. return false
  374. }
  375. return true
  376. }
  377. if !fn("test") {
  378. t.Errorf("should not be false")
  379. }
  380. }
  381. func TestBroadcastBinaryFilter(t *testing.T) {
  382. broadcast := NewTestServer()
  383. broadcast.m.HandleMessageBinary(func(session *Session, msg []byte) {
  384. broadcast.m.BroadcastBinaryFilter(msg, func(q *Session) bool {
  385. return session == q
  386. })
  387. })
  388. server := httptest.NewServer(broadcast)
  389. defer server.Close()
  390. fn := func(msg []byte) bool {
  391. conn, err := NewDialer(server.URL)
  392. defer conn.Close()
  393. if err != nil {
  394. t.Error(err)
  395. return false
  396. }
  397. conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
  398. messageType, ret, err := conn.ReadMessage()
  399. if err != nil {
  400. t.Error(err)
  401. return false
  402. }
  403. if messageType != websocket.BinaryMessage {
  404. t.Errorf("message type should be BinaryMessage")
  405. return false
  406. }
  407. if !bytes.Equal(msg, ret) {
  408. t.Errorf("%v should equal %v", msg, ret)
  409. return false
  410. }
  411. return true
  412. }
  413. if !fn([]byte{2, 3, 5, 7, 11}) {
  414. t.Errorf("should not be false")
  415. }
  416. }
  417. func TestStop(t *testing.T) {
  418. noecho := NewTestServer()
  419. server := httptest.NewServer(noecho)
  420. defer server.Close()
  421. conn, err := NewDialer(server.URL)
  422. defer conn.Close()
  423. if err != nil {
  424. t.Error(err)
  425. }
  426. noecho.m.Close()
  427. }
  428. func TestSmallMessageBuffer(t *testing.T) {
  429. echo := NewTestServerHandler(func(session *Session, msg []byte) {
  430. session.Write(msg)
  431. })
  432. echo.m.Config.MessageBufferSize = 0
  433. echo.m.HandleError(func(s *Session, err error) {
  434. if err == nil {
  435. t.Error("there should be a buffer full error here")
  436. }
  437. })
  438. server := httptest.NewServer(echo)
  439. defer server.Close()
  440. conn, err := NewDialer(server.URL)
  441. defer conn.Close()
  442. if err != nil {
  443. t.Error(err)
  444. }
  445. conn.WriteMessage(websocket.TextMessage, []byte("12345"))
  446. }
  447. func TestPong(t *testing.T) {
  448. echo := NewTestServerHandler(func(session *Session, msg []byte) {
  449. session.Write(msg)
  450. })
  451. echo.m.Config.PongWait = time.Second
  452. echo.m.Config.PingPeriod = time.Second * 9 / 10
  453. server := httptest.NewServer(echo)
  454. defer server.Close()
  455. conn, err := NewDialer(server.URL)
  456. defer conn.Close()
  457. if err != nil {
  458. t.Error(err)
  459. }
  460. fired := false
  461. echo.m.HandlePong(func(s *Session) {
  462. fired = true
  463. })
  464. conn.WriteMessage(websocket.PongMessage, nil)
  465. time.Sleep(time.Millisecond)
  466. if !fired {
  467. t.Error("should have fired pong handler")
  468. }
  469. }