You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

160 lines
4.2 KiB

  1. package suite
import (
	"flag"
	"fmt"
	"os"
	"reflect"
	"regexp"
	"runtime/debug"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
  12. var allTestsFilter = func(_, _ string) (bool, error) { return true, nil }
  13. var matchMethod = flag.String("testify.m", "", "regular expression to select tests of the testify suite to run")
  14. // Suite is a basic testing suite with methods for storing and
  15. // retrieving the current *testing.T context.
  16. type Suite struct {
  17. *assert.Assertions
  18. require *require.Assertions
  19. t *testing.T
  20. }
  21. // T retrieves the current *testing.T context.
  22. func (suite *Suite) T() *testing.T {
  23. return suite.t
  24. }
  25. // SetT sets the current *testing.T context.
  26. func (suite *Suite) SetT(t *testing.T) {
  27. suite.t = t
  28. suite.Assertions = assert.New(t)
  29. suite.require = require.New(t)
  30. }
  31. // Require returns a require context for suite.
  32. func (suite *Suite) Require() *require.Assertions {
  33. if suite.require == nil {
  34. suite.require = require.New(suite.T())
  35. }
  36. return suite.require
  37. }
  38. // Assert returns an assert context for suite. Normally, you can call
  39. // `suite.NoError(expected, actual)`, but for situations where the embedded
  40. // methods are overridden (for example, you might want to override
  41. // assert.Assertions with require.Assertions), this method is provided so you
  42. // can call `suite.Assert().NoError()`.
  43. func (suite *Suite) Assert() *assert.Assertions {
  44. if suite.Assertions == nil {
  45. suite.Assertions = assert.New(suite.T())
  46. }
  47. return suite.Assertions
  48. }
  49. func failOnPanic(t *testing.T) {
  50. r := recover()
  51. if r != nil {
  52. t.Errorf("test panicked: %v", r)
  53. t.FailNow()
  54. }
  55. }
  56. // Run provides suite functionality around golang subtests. It should be
  57. // called in place of t.Run(name, func(t *testing.T)) in test suite code.
  58. // The passed-in func will be executed as a subtest with a fresh instance of t.
  59. // Provides compatibility with go test pkg -run TestSuite/TestName/SubTestName.
  60. func (suite *Suite) Run(name string, subtest func()) bool {
  61. oldT := suite.T()
  62. defer suite.SetT(oldT)
  63. return oldT.Run(name, func(t *testing.T) {
  64. suite.SetT(t)
  65. subtest()
  66. })
  67. }
  68. // Run takes a testing suite and runs all of the tests attached
  69. // to it.
  70. func Run(t *testing.T, suite TestingSuite) {
  71. suite.SetT(t)
  72. defer failOnPanic(t)
  73. if setupAllSuite, ok := suite.(SetupAllSuite); ok {
  74. setupAllSuite.SetupSuite()
  75. }
  76. defer func() {
  77. if tearDownAllSuite, ok := suite.(TearDownAllSuite); ok {
  78. tearDownAllSuite.TearDownSuite()
  79. }
  80. }()
  81. methodFinder := reflect.TypeOf(suite)
  82. tests := []testing.InternalTest{}
  83. for index := 0; index < methodFinder.NumMethod(); index++ {
  84. method := methodFinder.Method(index)
  85. ok, err := methodFilter(method.Name)
  86. if err != nil {
  87. fmt.Fprintf(os.Stderr, "testify: invalid regexp for -m: %s\n", err)
  88. os.Exit(1)
  89. }
  90. if ok {
  91. test := testing.InternalTest{
  92. Name: method.Name,
  93. F: func(t *testing.T) {
  94. parentT := suite.T()
  95. suite.SetT(t)
  96. defer failOnPanic(t)
  97. if setupTestSuite, ok := suite.(SetupTestSuite); ok {
  98. setupTestSuite.SetupTest()
  99. }
  100. if beforeTestSuite, ok := suite.(BeforeTest); ok {
  101. beforeTestSuite.BeforeTest(methodFinder.Elem().Name(), method.Name)
  102. }
  103. defer func() {
  104. if afterTestSuite, ok := suite.(AfterTest); ok {
  105. afterTestSuite.AfterTest(methodFinder.Elem().Name(), method.Name)
  106. }
  107. if tearDownTestSuite, ok := suite.(TearDownTestSuite); ok {
  108. tearDownTestSuite.TearDownTest()
  109. }
  110. suite.SetT(parentT)
  111. }()
  112. method.Func.Call([]reflect.Value{reflect.ValueOf(suite)})
  113. },
  114. }
  115. tests = append(tests, test)
  116. }
  117. }
  118. runTests(t, tests)
  119. }
  120. func runTests(t testing.TB, tests []testing.InternalTest) {
  121. r, ok := t.(runner)
  122. if !ok { // backwards compatibility with Go 1.6 and below
  123. if !testing.RunTests(allTestsFilter, tests) {
  124. t.Fail()
  125. }
  126. return
  127. }
  128. for _, test := range tests {
  129. r.Run(test.Name, test.F)
  130. }
  131. }
  132. // Filtering method according to set regular expression
  133. // specified command-line argument -m
  134. func methodFilter(name string) (bool, error) {
  135. if ok, _ := regexp.MatchString("^Test", name); !ok {
  136. return false, nil
  137. }
  138. return regexp.MatchString(*matchMethod, name)
  139. }
  140. type runner interface {
  141. Run(name string, f func(t *testing.T)) bool
  142. }