You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

171 lines
3.6 KiB

2 years ago
  1. // Copyright 2018 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package builder
  5. import (
  6. sql2 "database/sql"
  7. "fmt"
  8. "reflect"
  9. "strings"
  10. "time"
  11. )
  12. func condToSQL(cond Cond) (string, []interface{}, error) {
  13. if cond == nil || !cond.IsValid() {
  14. return "", nil, nil
  15. }
  16. w := NewWriter()
  17. if err := cond.WriteTo(w); err != nil {
  18. return "", nil, err
  19. }
  20. return w.String(), w.args, nil
  21. }
  22. func condToBoundSQL(cond Cond) (string, error) {
  23. if cond == nil || !cond.IsValid() {
  24. return "", nil
  25. }
  26. w := NewWriter()
  27. if err := cond.WriteTo(w); err != nil {
  28. return "", err
  29. }
  30. return ConvertToBoundSQL(w.String(), w.args)
  31. }
  32. // ToSQL convert a builder or conditions to SQL and args
  33. func ToSQL(cond interface{}) (string, []interface{}, error) {
  34. switch cond.(type) {
  35. case Cond:
  36. return condToSQL(cond.(Cond))
  37. case *Builder:
  38. return cond.(*Builder).ToSQL()
  39. }
  40. return "", nil, ErrNotSupportType
  41. }
  42. // ToBoundSQL convert a builder or conditions to parameters bound SQL
  43. func ToBoundSQL(cond interface{}) (string, error) {
  44. switch cond.(type) {
  45. case Cond:
  46. return condToBoundSQL(cond.(Cond))
  47. case *Builder:
  48. return cond.(*Builder).ToBoundSQL()
  49. }
  50. return "", ErrNotSupportType
  51. }
  52. func noSQLQuoteNeeded(a interface{}) bool {
  53. if a == nil {
  54. return false
  55. }
  56. switch a.(type) {
  57. case Field:
  58. return true
  59. case int, int8, int16, int32, int64:
  60. return true
  61. case uint, uint8, uint16, uint32, uint64:
  62. return true
  63. case float32, float64:
  64. return true
  65. case bool:
  66. return true
  67. case string:
  68. return false
  69. case time.Time, *time.Time:
  70. return false
  71. }
  72. t := reflect.TypeOf(a)
  73. switch t.Kind() {
  74. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  75. return true
  76. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  77. return true
  78. case reflect.Float32, reflect.Float64:
  79. return true
  80. case reflect.Bool:
  81. return true
  82. case reflect.String:
  83. return false
  84. }
  85. return false
  86. }
  87. // ConvertToBoundSQL will convert SQL and args to a bound SQL
  88. func ConvertToBoundSQL(sql string, args []interface{}) (string, error) {
  89. buf := strings.Builder{}
  90. var i, j, start int
  91. for ; i < len(sql); i++ {
  92. if sql[i] == '?' {
  93. _, err := buf.WriteString(sql[start:i])
  94. if err != nil {
  95. return "", err
  96. }
  97. start = i + 1
  98. if len(args) == j {
  99. return "", ErrNeedMoreArguments
  100. }
  101. arg := args[j]
  102. if namedArg, ok := arg.(sql2.NamedArg); ok {
  103. arg = namedArg.Value
  104. }
  105. if noSQLQuoteNeeded(arg) {
  106. _, err = fmt.Fprint(&buf, arg)
  107. } else {
  108. // replace ' -> '' (standard replacement) to avoid critical SQL injection,
  109. // NOTICE: may allow some injection like % (or _) in LIKE query
  110. _, err = fmt.Fprintf(&buf, "'%v'", strings.Replace(fmt.Sprintf("%v", arg), "'",
  111. "''", -1))
  112. }
  113. if err != nil {
  114. return "", err
  115. }
  116. j = j + 1
  117. }
  118. }
  119. _, err := buf.WriteString(sql[start:])
  120. if err != nil {
  121. return "", err
  122. }
  123. return buf.String(), nil
  124. }
  125. // ConvertPlaceholder replaces the place holder ? to $1, $2 ... or :1, :2 ... according prefix
  126. func ConvertPlaceholder(sql, prefix string) (string, error) {
  127. buf := strings.Builder{}
  128. var i, j, start int
  129. var ready = true
  130. for ; i < len(sql); i++ {
  131. if sql[i] == '\'' && i > 0 && sql[i-1] != '\\' {
  132. ready = !ready
  133. }
  134. if ready && sql[i] == '?' {
  135. if _, err := buf.WriteString(sql[start:i]); err != nil {
  136. return "", err
  137. }
  138. start = i + 1
  139. j = j + 1
  140. if _, err := buf.WriteString(fmt.Sprintf("%v%d", prefix, j)); err != nil {
  141. return "", err
  142. }
  143. }
  144. }
  145. if _, err := buf.WriteString(sql[start:]); err != nil {
  146. return "", err
  147. }
  148. return buf.String(), nil
  149. }