websocket 增加多分组 fork https://github.com/olahol/melody
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

724 lines
14 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
9 years ago
  1. package melody
  2. import (
  3. "bytes"
  4. "github.com/gorilla/websocket"
  5. "math/rand"
  6. "net/http"
  7. "net/http/httptest"
  8. "strconv"
  9. "strings"
  10. "testing"
  11. "testing/quick"
  12. "time"
  13. )
  14. type TestServer struct {
  15. m *Melody
  16. }
  17. func NewTestServerHandler(handler handleMessageFunc) *TestServer {
  18. m := New()
  19. m.HandleMessage(handler)
  20. return &TestServer{
  21. m: m,
  22. }
  23. }
  24. func NewTestServer() *TestServer {
  25. m := New()
  26. return &TestServer{
  27. m: m,
  28. }
  29. }
  30. func (s *TestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  31. s.m.HandleRequest(w, r)
  32. }
  33. func NewDialer(url string) (*websocket.Conn, error) {
  34. dialer := &websocket.Dialer{}
  35. conn, _, err := dialer.Dial(strings.Replace(url, "http", "ws", 1), nil)
  36. return conn, err
  37. }
  38. func TestEcho(t *testing.T) {
  39. echo := NewTestServerHandler(func(session *Session, msg []byte) {
  40. session.Write(msg)
  41. })
  42. server := httptest.NewServer(echo)
  43. defer server.Close()
  44. fn := func(msg string) bool {
  45. conn, err := NewDialer(server.URL)
  46. defer conn.Close()
  47. if err != nil {
  48. t.Error(err)
  49. return false
  50. }
  51. conn.WriteMessage(websocket.TextMessage, []byte(msg))
  52. _, ret, err := conn.ReadMessage()
  53. if err != nil {
  54. t.Error(err)
  55. return false
  56. }
  57. if msg != string(ret) {
  58. t.Errorf("%s should equal %s", msg, string(ret))
  59. return false
  60. }
  61. return true
  62. }
  63. if err := quick.Check(fn, nil); err != nil {
  64. t.Error(err)
  65. }
  66. }
  67. func TestWriteClosed(t *testing.T) {
  68. echo := NewTestServerHandler(func(session *Session, msg []byte) {
  69. session.Write(msg)
  70. })
  71. server := httptest.NewServer(echo)
  72. defer server.Close()
  73. fn := func(msg string) bool {
  74. conn, err := NewDialer(server.URL)
  75. if err != nil {
  76. t.Error(err)
  77. return false
  78. }
  79. conn.WriteMessage(websocket.TextMessage, []byte(msg))
  80. echo.m.HandleConnect(func(s *Session) {
  81. s.Close()
  82. })
  83. echo.m.HandleDisconnect(func(s *Session) {
  84. err := s.Write([]byte("hello world"))
  85. if err == nil {
  86. t.Error("should be an error")
  87. }
  88. })
  89. return true
  90. }
  91. if err := quick.Check(fn, nil); err != nil {
  92. t.Error(err)
  93. }
  94. }
  95. func TestLen(t *testing.T) {
  96. rand.Seed(time.Now().UnixNano())
  97. connect := int(rand.Int31n(100))
  98. disconnect := rand.Float32()
  99. conns := make([]*websocket.Conn, connect)
  100. defer func() {
  101. for _, conn := range conns {
  102. if conn != nil {
  103. conn.Close()
  104. }
  105. }
  106. }()
  107. echo := NewTestServerHandler(func(session *Session, msg []byte) {})
  108. server := httptest.NewServer(echo)
  109. defer server.Close()
  110. disconnected := 0
  111. for i := 0; i < connect; i++ {
  112. conn, err := NewDialer(server.URL)
  113. if err != nil {
  114. t.Error(err)
  115. }
  116. if rand.Float32() < disconnect {
  117. conns[i] = nil
  118. disconnected += 1
  119. conn.Close()
  120. continue
  121. }
  122. conns[i] = conn
  123. }
  124. time.Sleep(time.Millisecond)
  125. connected := connect - disconnected
  126. if echo.m.Len() != connected {
  127. t.Errorf("melody len %d should equal %d", echo.m.Len(), connected)
  128. }
  129. }
  130. func TestEchoBinary(t *testing.T) {
  131. echo := NewTestServer()
  132. echo.m.HandleMessageBinary(func(session *Session, msg []byte) {
  133. session.WriteBinary(msg)
  134. })
  135. server := httptest.NewServer(echo)
  136. defer server.Close()
  137. fn := func(msg string) bool {
  138. conn, err := NewDialer(server.URL)
  139. defer conn.Close()
  140. if err != nil {
  141. t.Error(err)
  142. return false
  143. }
  144. conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
  145. _, ret, err := conn.ReadMessage()
  146. if err != nil {
  147. t.Error(err)
  148. return false
  149. }
  150. if msg != string(ret) {
  151. t.Errorf("%s should equal %s", msg, string(ret))
  152. return false
  153. }
  154. return true
  155. }
  156. if err := quick.Check(fn, nil); err != nil {
  157. t.Error(err)
  158. }
  159. }
  160. func TestHandlers(t *testing.T) {
  161. echo := NewTestServer()
  162. echo.m.HandleMessage(func(session *Session, msg []byte) {
  163. session.Write(msg)
  164. })
  165. server := httptest.NewServer(echo)
  166. defer server.Close()
  167. var q *Session
  168. echo.m.HandleConnect(func(session *Session) {
  169. q = session
  170. session.Close()
  171. })
  172. echo.m.HandleDisconnect(func(session *Session) {
  173. if q != session {
  174. t.Error("disconnecting session should be the same as connecting")
  175. }
  176. })
  177. NewDialer(server.URL)
  178. }
  179. func TestMetadata(t *testing.T) {
  180. echo := NewTestServer()
  181. echo.m.HandleConnect(func(session *Session) {
  182. session.Set("stamp", time.Now().UnixNano())
  183. })
  184. echo.m.HandleMessage(func(session *Session, msg []byte) {
  185. stamp := session.MustGet("stamp").(int64)
  186. session.Write([]byte(strconv.Itoa(int(stamp))))
  187. })
  188. server := httptest.NewServer(echo)
  189. defer server.Close()
  190. fn := func(msg string) bool {
  191. conn, err := NewDialer(server.URL)
  192. defer conn.Close()
  193. if err != nil {
  194. t.Error(err)
  195. return false
  196. }
  197. conn.WriteMessage(websocket.TextMessage, []byte(msg))
  198. _, ret, err := conn.ReadMessage()
  199. if err != nil {
  200. t.Error(err)
  201. return false
  202. }
  203. stamp, err := strconv.Atoi(string(ret))
  204. if err != nil {
  205. t.Error(err)
  206. return false
  207. }
  208. diff := int(time.Now().UnixNano()) - stamp
  209. if diff <= 0 {
  210. t.Errorf("diff should be above 0 %d", diff)
  211. return false
  212. }
  213. return true
  214. }
  215. if err := quick.Check(fn, nil); err != nil {
  216. t.Error(err)
  217. }
  218. }
  219. func TestUpgrader(t *testing.T) {
  220. broadcast := NewTestServer()
  221. broadcast.m.HandleMessage(func(session *Session, msg []byte) {
  222. session.Write(msg)
  223. })
  224. server := httptest.NewServer(broadcast)
  225. defer server.Close()
  226. broadcast.m.Upgrader = &websocket.Upgrader{
  227. ReadBufferSize: 1024,
  228. WriteBufferSize: 1024,
  229. CheckOrigin: func(r *http.Request) bool { return false },
  230. }
  231. broadcast.m.HandleError(func(session *Session, err error) {
  232. if err == nil || err.Error() != "websocket: origin not allowed" {
  233. t.Error("there should be a origin error")
  234. }
  235. })
  236. _, err := NewDialer(server.URL)
  237. if err == nil || err.Error() != "websocket: bad handshake" {
  238. t.Error("there should be a badhandshake error")
  239. }
  240. }
  241. func TestBroadcast(t *testing.T) {
  242. broadcast := NewTestServer()
  243. broadcast.m.HandleMessage(func(session *Session, msg []byte) {
  244. broadcast.m.Broadcast(msg)
  245. })
  246. server := httptest.NewServer(broadcast)
  247. defer server.Close()
  248. n := 10
  249. fn := func(msg string) bool {
  250. conn, _ := NewDialer(server.URL)
  251. defer conn.Close()
  252. listeners := make([]*websocket.Conn, n)
  253. for i := 0; i < n; i++ {
  254. listener, _ := NewDialer(server.URL)
  255. listeners[i] = listener
  256. defer listeners[i].Close()
  257. }
  258. conn.WriteMessage(websocket.TextMessage, []byte(msg))
  259. for i := 0; i < n; i++ {
  260. _, ret, err := listeners[i].ReadMessage()
  261. if err != nil {
  262. t.Error(err)
  263. return false
  264. }
  265. if msg != string(ret) {
  266. t.Errorf("%s should equal %s", msg, string(ret))
  267. return false
  268. }
  269. }
  270. return true
  271. }
  272. if !fn("test") {
  273. t.Errorf("should not be false")
  274. }
  275. }
  276. func TestBroadcastBinary(t *testing.T) {
  277. broadcast := NewTestServer()
  278. broadcast.m.HandleMessageBinary(func(session *Session, msg []byte) {
  279. broadcast.m.BroadcastBinary(msg)
  280. })
  281. server := httptest.NewServer(broadcast)
  282. defer server.Close()
  283. n := 10
  284. fn := func(msg []byte) bool {
  285. conn, _ := NewDialer(server.URL)
  286. defer conn.Close()
  287. listeners := make([]*websocket.Conn, n)
  288. for i := 0; i < n; i++ {
  289. listener, _ := NewDialer(server.URL)
  290. listeners[i] = listener
  291. defer listeners[i].Close()
  292. }
  293. conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
  294. for i := 0; i < n; i++ {
  295. messageType, ret, err := listeners[i].ReadMessage()
  296. if err != nil {
  297. t.Error(err)
  298. return false
  299. }
  300. if messageType != websocket.BinaryMessage {
  301. t.Errorf("message type should be BinaryMessage")
  302. return false
  303. }
  304. if !bytes.Equal(msg, ret) {
  305. t.Errorf("%v should equal %v", msg, ret)
  306. return false
  307. }
  308. }
  309. return true
  310. }
  311. if !fn([]byte{2, 3, 5, 7, 11}) {
  312. t.Errorf("should not be false")
  313. }
  314. }
  315. func TestBroadcastOthers(t *testing.T) {
  316. broadcast := NewTestServer()
  317. broadcast.m.HandleMessage(func(session *Session, msg []byte) {
  318. broadcast.m.BroadcastOthers(msg, session)
  319. })
  320. broadcast.m.Config.PongWait = time.Second
  321. broadcast.m.Config.PingPeriod = time.Second * 9 / 10
  322. server := httptest.NewServer(broadcast)
  323. defer server.Close()
  324. n := 10
  325. fn := func(msg string) bool {
  326. conn, _ := NewDialer(server.URL)
  327. defer conn.Close()
  328. listeners := make([]*websocket.Conn, n)
  329. for i := 0; i < n; i++ {
  330. listener, _ := NewDialer(server.URL)
  331. listeners[i] = listener
  332. defer listeners[i].Close()
  333. }
  334. conn.WriteMessage(websocket.TextMessage, []byte(msg))
  335. for i := 0; i < n; i++ {
  336. _, ret, err := listeners[i].ReadMessage()
  337. if err != nil {
  338. t.Error(err)
  339. return false
  340. }
  341. if msg != string(ret) {
  342. t.Errorf("%s should equal %s", msg, string(ret))
  343. return false
  344. }
  345. }
  346. return true
  347. }
  348. if !fn("test") {
  349. t.Errorf("should not be false")
  350. }
  351. }
  352. func TestBroadcastBinaryOthers(t *testing.T) {
  353. broadcast := NewTestServer()
  354. broadcast.m.HandleMessageBinary(func(session *Session, msg []byte) {
  355. broadcast.m.BroadcastBinaryOthers(msg, session)
  356. })
  357. broadcast.m.Config.PongWait = time.Second
  358. broadcast.m.Config.PingPeriod = time.Second * 9 / 10
  359. server := httptest.NewServer(broadcast)
  360. defer server.Close()
  361. n := 10
  362. fn := func(msg []byte) bool {
  363. conn, _ := NewDialer(server.URL)
  364. defer conn.Close()
  365. listeners := make([]*websocket.Conn, n)
  366. for i := 0; i < n; i++ {
  367. listener, _ := NewDialer(server.URL)
  368. listeners[i] = listener
  369. defer listeners[i].Close()
  370. }
  371. conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
  372. for i := 0; i < n; i++ {
  373. messageType, ret, err := listeners[i].ReadMessage()
  374. if err != nil {
  375. t.Error(err)
  376. return false
  377. }
  378. if messageType != websocket.BinaryMessage {
  379. t.Errorf("message type should be BinaryMessage")
  380. return false
  381. }
  382. if !bytes.Equal(msg, ret) {
  383. t.Errorf("%v should equal %v", msg, ret)
  384. return false
  385. }
  386. }
  387. return true
  388. }
  389. if !fn([]byte{2, 3, 5, 7, 11}) {
  390. t.Errorf("should not be false")
  391. }
  392. }
  393. func TestPingPong(t *testing.T) {
  394. noecho := NewTestServer()
  395. noecho.m.Config.PongWait = time.Second
  396. noecho.m.Config.PingPeriod = time.Second * 9 / 10
  397. server := httptest.NewServer(noecho)
  398. defer server.Close()
  399. conn, err := NewDialer(server.URL)
  400. conn.SetPingHandler(func(string) error {
  401. return nil
  402. })
  403. defer conn.Close()
  404. if err != nil {
  405. t.Error(err)
  406. }
  407. conn.WriteMessage(websocket.TextMessage, []byte("test"))
  408. _, _, err = conn.ReadMessage()
  409. if err == nil {
  410. t.Error("there should be an error")
  411. }
  412. }
  413. func TestBroadcastFilter(t *testing.T) {
  414. broadcast := NewTestServer()
  415. broadcast.m.HandleMessage(func(session *Session, msg []byte) {
  416. broadcast.m.BroadcastFilter(msg, func(q *Session) bool {
  417. return session == q
  418. })
  419. })
  420. server := httptest.NewServer(broadcast)
  421. defer server.Close()
  422. fn := func(msg string) bool {
  423. conn, err := NewDialer(server.URL)
  424. defer conn.Close()
  425. if err != nil {
  426. t.Error(err)
  427. return false
  428. }
  429. conn.WriteMessage(websocket.TextMessage, []byte(msg))
  430. _, ret, err := conn.ReadMessage()
  431. if err != nil {
  432. t.Error(err)
  433. return false
  434. }
  435. if msg != string(ret) {
  436. t.Errorf("%s should equal %s", msg, string(ret))
  437. return false
  438. }
  439. return true
  440. }
  441. if !fn("test") {
  442. t.Errorf("should not be false")
  443. }
  444. }
  445. func TestBroadcastBinaryFilter(t *testing.T) {
  446. broadcast := NewTestServer()
  447. broadcast.m.HandleMessageBinary(func(session *Session, msg []byte) {
  448. broadcast.m.BroadcastBinaryFilter(msg, func(q *Session) bool {
  449. return session == q
  450. })
  451. })
  452. server := httptest.NewServer(broadcast)
  453. defer server.Close()
  454. fn := func(msg []byte) bool {
  455. conn, err := NewDialer(server.URL)
  456. defer conn.Close()
  457. if err != nil {
  458. t.Error(err)
  459. return false
  460. }
  461. conn.WriteMessage(websocket.BinaryMessage, []byte(msg))
  462. messageType, ret, err := conn.ReadMessage()
  463. if err != nil {
  464. t.Error(err)
  465. return false
  466. }
  467. if messageType != websocket.BinaryMessage {
  468. t.Errorf("message type should be BinaryMessage")
  469. return false
  470. }
  471. if !bytes.Equal(msg, ret) {
  472. t.Errorf("%v should equal %v", msg, ret)
  473. return false
  474. }
  475. return true
  476. }
  477. if !fn([]byte{2, 3, 5, 7, 11}) {
  478. t.Errorf("should not be false")
  479. }
  480. }
  481. func TestStop(t *testing.T) {
  482. noecho := NewTestServer()
  483. server := httptest.NewServer(noecho)
  484. defer server.Close()
  485. conn, err := NewDialer(server.URL)
  486. defer conn.Close()
  487. if err != nil {
  488. t.Error(err)
  489. }
  490. noecho.m.Close()
  491. }
  492. func TestSmallMessageBuffer(t *testing.T) {
  493. echo := NewTestServerHandler(func(session *Session, msg []byte) {
  494. session.Write(msg)
  495. })
  496. echo.m.Config.MessageBufferSize = 0
  497. echo.m.HandleError(func(s *Session, err error) {
  498. if err == nil {
  499. t.Error("there should be a buffer full error here")
  500. }
  501. })
  502. server := httptest.NewServer(echo)
  503. defer server.Close()
  504. conn, err := NewDialer(server.URL)
  505. defer conn.Close()
  506. if err != nil {
  507. t.Error(err)
  508. }
  509. conn.WriteMessage(websocket.TextMessage, []byte("12345"))
  510. }
  511. func TestPong(t *testing.T) {
  512. echo := NewTestServerHandler(func(session *Session, msg []byte) {
  513. session.Write(msg)
  514. })
  515. echo.m.Config.PongWait = time.Second
  516. echo.m.Config.PingPeriod = time.Second * 9 / 10
  517. server := httptest.NewServer(echo)
  518. defer server.Close()
  519. conn, err := NewDialer(server.URL)
  520. defer conn.Close()
  521. if err != nil {
  522. t.Error(err)
  523. }
  524. fired := false
  525. echo.m.HandlePong(func(s *Session) {
  526. fired = true
  527. })
  528. conn.WriteMessage(websocket.PongMessage, nil)
  529. time.Sleep(time.Millisecond)
  530. if !fired {
  531. t.Error("should have fired pong handler")
  532. }
  533. }
  534. func BenchmarkSessionWrite(b *testing.B) {
  535. echo := NewTestServerHandler(func(session *Session, msg []byte) {
  536. session.Write(msg)
  537. })
  538. server := httptest.NewServer(echo)
  539. conn, _ := NewDialer(server.URL)
  540. defer server.Close()
  541. defer conn.Close()
  542. for n := 0; n < b.N; n++ {
  543. conn.WriteMessage(websocket.TextMessage, []byte("test"))
  544. conn.ReadMessage()
  545. }
  546. }
  547. func BenchmarkBroadcast(b *testing.B) {
  548. echo := NewTestServerHandler(func(session *Session, msg []byte) {
  549. session.Write(msg)
  550. })
  551. server := httptest.NewServer(echo)
  552. defer server.Close()
  553. conns := make([]*websocket.Conn, 0)
  554. num := 100
  555. for i := 0; i < num; i++ {
  556. conn, _ := NewDialer(server.URL)
  557. conns = append(conns, conn)
  558. }
  559. for n := 0; n < b.N; n++ {
  560. echo.m.Broadcast([]byte("test"))
  561. for i := 0; i < num; i++ {
  562. conns[i].ReadMessage()
  563. }
  564. }
  565. for i := 0; i < num; i++ {
  566. conns[i].Close()
  567. }
  568. }