You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

316 lines
7.4 KiB

  1. // This program reads all assertion functions from the assert package and
  2. // automatically generates the corresponding requires and forwarded assertions
  3. package main
  4. import (
  5. "bytes"
  6. "flag"
  7. "fmt"
  8. "go/ast"
  9. "go/build"
  10. "go/doc"
  11. "go/format"
  12. "go/importer"
  13. "go/parser"
  14. "go/token"
  15. "go/types"
  16. "io"
  17. "io/ioutil"
  18. "log"
  19. "os"
  20. "path"
  21. "regexp"
  22. "strings"
  23. "text/template"
  24. "github.com/ernesto-jimenez/gogen/imports"
  25. )
  // Command-line flags. They are registered at package initialisation and
  // populated by flag.Parse in main.
  var (
  	// pkg is the import path of the package whose assertion functions are read.
  	pkg = flag.String("assert-path", "github.com/stretchr/testify/assert", "Path to the assert package")
  	// includeF keeps functions whose name ends in "f" (skipped by default).
  	includeF = flag.Bool("include-format-funcs", false, "include format functions such as Errorf and Equalf")
  	// outputPkg is the package name written into the generated file; it is
  	// also handed to the import collector and used to decide which type
  	// names need no package qualifier.
  	outputPkg = flag.String("output-package", "", "package for the resulting code")
  	// tmplFile optionally replaces the built-in per-function template and,
  	// when out is empty, determines the name of the output file.
  	tmplFile = flag.String("template", "", "What file to load the function template from")
  	// out is the destination file; "-" selects standard output.
  	out = flag.String("out", "", "What file to write the source code to")
  )
  33. func main() {
  34. flag.Parse()
  35. scope, docs, err := parsePackageSource(*pkg)
  36. if err != nil {
  37. log.Fatal(err)
  38. }
  39. importer, funcs, err := analyzeCode(scope, docs)
  40. if err != nil {
  41. log.Fatal(err)
  42. }
  43. if err := generateCode(importer, funcs); err != nil {
  44. log.Fatal(err)
  45. }
  46. }
  47. func generateCode(importer imports.Importer, funcs []testFunc) error {
  48. buff := bytes.NewBuffer(nil)
  49. tmplHead, tmplFunc, err := parseTemplates()
  50. if err != nil {
  51. return err
  52. }
  53. // Generate header
  54. if err := tmplHead.Execute(buff, struct {
  55. Name string
  56. Imports map[string]string
  57. }{
  58. *outputPkg,
  59. importer.Imports(),
  60. }); err != nil {
  61. return err
  62. }
  63. // Generate funcs
  64. for _, fn := range funcs {
  65. buff.Write([]byte("\n\n"))
  66. if err := tmplFunc.Execute(buff, &fn); err != nil {
  67. return err
  68. }
  69. }
  70. code, err := format.Source(buff.Bytes())
  71. if err != nil {
  72. return err
  73. }
  74. // Write file
  75. output, err := outputFile()
  76. if err != nil {
  77. return err
  78. }
  79. defer output.Close()
  80. _, err = io.Copy(output, bytes.NewReader(code))
  81. return err
  82. }
  83. func parseTemplates() (*template.Template, *template.Template, error) {
  84. tmplHead, err := template.New("header").Parse(headerTemplate)
  85. if err != nil {
  86. return nil, nil, err
  87. }
  88. if *tmplFile != "" {
  89. f, err := ioutil.ReadFile(*tmplFile)
  90. if err != nil {
  91. return nil, nil, err
  92. }
  93. funcTemplate = string(f)
  94. }
  95. tmpl, err := template.New("function").Parse(funcTemplate)
  96. if err != nil {
  97. return nil, nil, err
  98. }
  99. return tmplHead, tmpl, nil
  100. }
  101. func outputFile() (*os.File, error) {
  102. filename := *out
  103. if filename == "-" || (filename == "" && *tmplFile == "") {
  104. return os.Stdout, nil
  105. }
  106. if filename == "" {
  107. filename = strings.TrimSuffix(strings.TrimSuffix(*tmplFile, ".tmpl"), ".go") + ".go"
  108. }
  109. return os.Create(filename)
  110. }
  111. // analyzeCode takes the types scope and the docs and returns the import
  112. // information and information about all the assertion functions.
  113. func analyzeCode(scope *types.Scope, docs *doc.Package) (imports.Importer, []testFunc, error) {
  114. testingT := scope.Lookup("TestingT").Type().Underlying().(*types.Interface)
  115. importer := imports.New(*outputPkg)
  116. var funcs []testFunc
  117. // Go through all the top level functions
  118. for _, fdocs := range docs.Funcs {
  119. // Find the function
  120. obj := scope.Lookup(fdocs.Name)
  121. fn, ok := obj.(*types.Func)
  122. if !ok {
  123. continue
  124. }
  125. // Check function signature has at least two arguments
  126. sig := fn.Type().(*types.Signature)
  127. if sig.Params().Len() < 2 {
  128. continue
  129. }
  130. // Check first argument is of type testingT
  131. first, ok := sig.Params().At(0).Type().(*types.Named)
  132. if !ok {
  133. continue
  134. }
  135. firstType, ok := first.Underlying().(*types.Interface)
  136. if !ok {
  137. continue
  138. }
  139. if !types.Implements(firstType, testingT) {
  140. continue
  141. }
  142. // Skip functions ending with f
  143. if strings.HasSuffix(fdocs.Name, "f") && !*includeF {
  144. continue
  145. }
  146. funcs = append(funcs, testFunc{*outputPkg, fdocs, fn})
  147. importer.AddImportsFrom(sig.Params())
  148. }
  149. return importer, funcs, nil
  150. }
  // parsePackageSource returns the types scope and the package documentation from the package
  //
  // pkg is resolved with go/build relative to the current directory. Every
  // file listed in GoFiles is parsed (with comments), the set is type-checked
  // to obtain the package scope, and go/doc extracts the documentation.
  //
  // Files are read from the directory go/build reports for the package (Dir)
  // rather than from SrcRoot joined with ImportPath, which is wrong whenever
  // the package does not live at that location.
  func parsePackageSource(pkg string) (*types.Scope, *doc.Package, error) {
  	pd, err := build.Import(pkg, ".", 0)
  	if err != nil {
  		return nil, nil, err
  	}

  	fset := token.NewFileSet()
  	files := make(map[string]*ast.File, len(pd.GoFiles))
  	fileList := make([]*ast.File, len(pd.GoFiles))
  	for i, fname := range pd.GoFiles {
  		src, err := ioutil.ReadFile(path.Join(pd.Dir, fname))
  		if err != nil {
  			return nil, nil, err
  		}
  		f, err := parser.ParseFile(fset, fname, src, parser.ParseComments|parser.AllErrors)
  		if err != nil {
  			return nil, nil, err
  		}
  		files[fname] = f
  		fileList[i] = f
  	}

  	cfg := types.Config{
  		Importer: importer.Default(),
  	}
  	info := types.Info{
  		Defs: make(map[*ast.Ident]types.Object),
  	}
  	tp, err := cfg.Check(pkg, fset, fileList, &info)
  	if err != nil {
  		return nil, nil, err
  	}
  	scope := tp.Scope()

  	// The error from ast.NewPackage is deliberately ignored: no importer is
  	// supplied, so it presumably only reports identifiers it could not
  	// resolve, and the files were already validated by the type check above.
  	ap, _ := ast.NewPackage(fset, files, nil, nil)
  	docs := doc.New(ap, pkg, 0)

  	return scope, docs, nil
  }
  // testFunc bundles what the templates need to render one assertion: the name
  // of the package the generated code belongs to, the parsed documentation of
  // the function and its type-checked declaration.
  type testFunc struct {
  	CurrentPkg string
  	DocInfo    *doc.Func
  	TypeInfo   *types.Func
  }

  // Qualifier returns the package name to print in front of a type declared in
  // p. It returns "" for a nil package and for a package whose name equals
  // CurrentPkg, so such types are printed unqualified.
  func (f *testFunc) Qualifier(p *types.Package) string {
  	if p == nil || p.Name() == f.CurrentPkg {
  		return ""
  	}
  	return p.Name()
  }

  // paramList renders every parameter after the first one and joins the
  // results with ", ". Regular parameters go through render; a trailing
  // variadic parameter goes through renderVariadic.
  func (f *testFunc) paramList(render, renderVariadic func(*types.Var) string) string {
  	sig := f.TypeInfo.Type().(*types.Signature)
  	params := sig.Params()

  	last := params.Len()
  	if sig.Variadic() {
  		last--
  	}

  	var parts []string
  	for i := 1; i < last; i++ {
  		parts = append(parts, render(params.At(i)))
  	}
  	if sig.Variadic() {
  		parts = append(parts, renderVariadic(params.At(params.Len()-1)))
  	}
  	return strings.Join(parts, ", ")
  }

  // Params returns the parameter declarations ("name type, ...") of the
  // function without its first parameter. A variadic parameter is written as
  // "name ...elem".
  func (f *testFunc) Params() string {
  	return f.paramList(
  		func(v *types.Var) string {
  			return v.Name() + " " + types.TypeString(v.Type(), f.Qualifier)
  		},
  		func(v *types.Var) string {
  			elem := v.Type().(*types.Slice).Elem()
  			return v.Name() + " ..." + types.TypeString(elem, f.Qualifier)
  		},
  	)
  }

  // ForwardedParams returns the argument list ("name, ...") used to pass the
  // parameters listed by Params on to another call. A variadic parameter is
  // forwarded as "name...".
  func (f *testFunc) ForwardedParams() string {
  	return f.paramList(
  		func(v *types.Var) string { return v.Name() },
  		func(v *types.Var) string { return v.Name() + "..." },
  	)
  }

  // ParamsFormat is Params with the first "msgAndArgs" replaced by
  // "msg string, args".
  func (f *testFunc) ParamsFormat() string {
  	return strings.Replace(f.Params(), "msgAndArgs", "msg string, args", 1)
  }

  // ForwardedParamsFormat is ForwardedParams with the first "msgAndArgs"
  // replaced by an expression that prepends msg to args.
  func (f *testFunc) ForwardedParamsFormat() string {
  	return strings.Replace(f.ForwardedParams(), "msgAndArgs", "append([]interface{}{msg}, args...)", 1)
  }

  // Comment returns the function's documentation as a block of "// " line
  // comments, with surrounding whitespace of the original text trimmed.
  func (f *testFunc) Comment() string {
  	return "// " + strings.Replace(strings.TrimSpace(f.DocInfo.Doc), "\n", "\n// ", -1)
  }

  // CommentFormat returns Comment rewritten for the "f" variant of the
  // function: the function name gains an "f" suffix and every call to it in
  // the comment gets example format arguments appended.
  //
  // Only whole-word occurrences of the name are renamed, so longer identifiers
  // that merely contain it (for example "expectedError" when the function is
  // "Error") are left untouched.
  func (f *testFunc) CommentFormat() string {
  	name := f.DocInfo.Name
  	formatted := name + "f"

  	// The patterns depend on the function name, so they are built per call.
  	word := regexp.MustCompile(`\b` + regexp.QuoteMeta(name) + `\b`)
  	comment := word.ReplaceAllLiteralString(f.Comment(), formatted)

  	call := regexp.MustCompile(`\b` + regexp.QuoteMeta(formatted) + `\(((\(\)|[^)])+)\)`)
  	return call.ReplaceAllString(comment, formatted+`($1, "error message %s", "formatted")`)
  }

  // CommentWithoutT returns Comment with calls of the form "assert.Name(t, "
  // rewritten as method calls on receiver, i.e. "receiver.Name(".
  func (f *testFunc) CommentWithoutT(receiver string) string {
  	search := fmt.Sprintf("assert.%s(t, ", f.DocInfo.Name)
  	replace := fmt.Sprintf("%s.%s(", receiver, f.DocInfo.Name)
  	return strings.Replace(f.Comment(), search, replace, -1)
  }
  // headerTemplate is the text/template source for the top of the generated
  // file. It is executed with a value exposing Name (the output package name)
  // and Imports (a map ranged over as import path -> local name).
  var headerTemplate = `/*
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package {{.Name}}
import (
{{range $path, $name := .Imports}}
{{$name}} "{{$path}}"{{end}}
)
`
  // funcTemplate is the default text/template source rendered once per
  // assertion function, with a *testFunc as its data. parseTemplates replaces
  // it with the contents of the file named by the template flag, when set.
  var funcTemplate = `{{.Comment}}
func (fwd *AssertionsForwarder) {{.DocInfo.Name}}({{.Params}}) bool {
return assert.{{.DocInfo.Name}}({{.ForwardedParams}})
}`