diff --git a/.gitignore b/.gitignore index f4d432a..66fd13c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -# ---> Go # Binaries for programs and plugins *.exe *.exe~ @@ -14,4 +13,3 @@ # Dependency directories (remove the comment below to include it) # vendor/ - diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..bafa23f --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "commentTranslate.source": "DarkCWK.youdao-youdao" +} \ No newline at end of file diff --git a/README.md b/README.md index 60d7e12..cdab7fe 100644 --- a/README.md +++ b/README.md @@ -1,2 +1 @@ -# virtualsql - +# vql \ No newline at end of file diff --git a/builder/builder.go b/builder/builder.go new file mode 100644 index 0000000..10b6076 --- /dev/null +++ b/builder/builder.go @@ -0,0 +1,322 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + sql2 "database/sql" + "fmt" +) + +type optype byte + +const ( + condType optype = iota // only conditions + selectType // select + insertType // insert + updateType // update + deleteType // delete + setOpType // set operation +) + +// all databasees +const ( + POSTGRES = "postgres" + SQLITE = "sqlite3" + MYSQL = "mysql" + MSSQL = "mssql" + ORACLE = "oracle" + + UNION = "union" + INTERSECT = "intersect" + EXCEPT = "except" +) + +type join struct { + joinType string + joinTable interface{} + joinCond Cond + joinAlias string +} + +type setOp struct { + opType string + distinctType string + builder *Builder +} + +type limit struct { + limitN int + offset int +} + +// Builder describes a SQL statement +type Builder struct { + optype + dialect string + isNested bool + into string + from string + subQuery *Builder + cond Cond + selects []string + joins []join + setOps []setOp + limitation *limit + insertCols []string + insertVals []interface{} + updates []UpdateCond + orderBy string + groupBy string + having string +} + +// Dialect sets the db dialect of Builder. +func Dialect(dialect string) *Builder { + builder := &Builder{cond: NewCond(), dialect: dialect} + return builder +} + +// MySQL is shortcut of Dialect(MySQL) +func MySQL() *Builder { + return Dialect(MYSQL) +} + +// MsSQL is shortcut of Dialect(MsSQL) +func MsSQL() *Builder { + return Dialect(MSSQL) +} + +// Oracle is shortcut of Dialect(Oracle) +func Oracle() *Builder { + return Dialect(ORACLE) +} + +// Postgres is shortcut of Dialect(Postgres) +func Postgres() *Builder { + return Dialect(POSTGRES) +} + +// SQLite is shortcut of Dialect(SQLITE) +func SQLite() *Builder { + return Dialect(SQLITE) +} + +// Where sets where SQL +func (b *Builder) Where(cond Cond) *Builder { + if b.cond.IsValid() { + b.cond = b.cond.And(cond) + } else { + b.cond = cond + } + return b +} + +// From sets from subject(can be a table name in string or a builder pointer) and its alias +func (b *Builder) From(subject interface{}, alias ...string) *Builder { + switch subject.(type) { + case *Builder: + b.subQuery = subject.(*Builder) + + if len(alias) > 0 { + b.from = alias[0] + } else { + b.isNested = true + } + case string: + b.from = subject.(string) + + if len(alias) > 0 { + b.from = b.from + " " + alias[0] + } + } + + return b +} + +// TableName returns the table name +func (b *Builder) TableName() string { + if b.optype == insertType { + return b.into + } + return b.from +} + +// Into sets insert table name +func (b *Builder) Into(tableName string) *Builder { + b.into = tableName + return b +} + +// Union sets union conditions +func (b *Builder) Union(distinctType string, cond *Builder) *Builder { + return b.setOperation(UNION, distinctType, cond) +} + +// Intersect sets intersect conditions +func (b *Builder) Intersect(distinctType string, cond *Builder) *Builder { + return b.setOperation(INTERSECT, distinctType, cond) +} + +// Except sets except conditions +func (b *Builder) Except(distinctType string, cond *Builder) *Builder { + return b.setOperation(EXCEPT, distinctType, cond) +} + +func (b *Builder) setOperation(opType, distinctType string, cond *Builder) *Builder { + + var builder *Builder + if b.optype != setOpType { + builder = &Builder{cond: NewCond()} + builder.optype = setOpType + builder.dialect = b.dialect + builder.selects = b.selects + + currentSetOps := b.setOps + // erase sub setOps (actually append to new Builder.unions) + b.setOps = nil + + for e := range currentSetOps { + currentSetOps[e].builder.dialect = b.dialect + } + + builder.setOps = append(append(builder.setOps, setOp{opType, "", b}), currentSetOps...) + } else { + builder = b + } + + if cond != nil { + if cond.dialect == "" && builder.dialect != "" { + cond.dialect = builder.dialect + } + + builder.setOps = append(builder.setOps, setOp{opType, distinctType, cond}) + } + + return builder +} + +// Limit sets limitN condition +func (b *Builder) Limit(limitN int, offset ...int) *Builder { + b.limitation = &limit{limitN: limitN} + + if len(offset) > 0 { + b.limitation.offset = offset[0] + } + + return b +} + +// Select sets select SQL +func (b *Builder) Select(cols ...string) *Builder { + b.selects = cols + if b.optype == condType { + b.optype = selectType + } + return b +} + +// And sets AND condition +func (b *Builder) And(cond Cond) *Builder { + b.cond = And(b.cond, cond) + return b +} + +// Or sets OR condition +func (b *Builder) Or(cond Cond) *Builder { + b.cond = Or(b.cond, cond) + return b +} + +// Update sets update SQL +func (b *Builder) Update(updates ...Cond) *Builder { + b.updates = make([]UpdateCond, 0, len(updates)) + for _, update := range updates { + if u, ok := update.(UpdateCond); ok && u.IsValid() { + b.updates = append(b.updates, u) + } + } + b.optype = updateType + return b +} + +// Delete sets delete SQL +func (b *Builder) Delete(conds ...Cond) *Builder { + b.cond = b.cond.And(conds...) + b.optype = deleteType + return b +} + +// WriteTo implements Writer interface +func (b *Builder) WriteTo(w Writer) error { + switch b.optype { + /*case condType: + return b.cond.WriteTo(w)*/ + case selectType: + return b.selectWriteTo(w) + case insertType: + return b.insertWriteTo(w) + case updateType: + return b.updateWriteTo(w) + case deleteType: + return b.deleteWriteTo(w) + case setOpType: + return b.setOpWriteTo(w) + } + + return ErrNotSupportType +} + +// ToSQL convert a builder to SQL and args +func (b *Builder) ToSQL() (string, []interface{}, error) { + w := NewWriter() + if err := b.WriteTo(w); err != nil { + return "", nil, err + } + + // in case of sql.NamedArg in args + for e := range w.args { + if namedArg, ok := w.args[e].(sql2.NamedArg); ok { + w.args[e] = namedArg.Value + } + } + + var sql = w.String() + var err error + + switch b.dialect { + case ORACLE, MSSQL: + // This is for compatibility with different sql drivers + for e := range w.args { + w.args[e] = sql2.Named(fmt.Sprintf("p%d", e+1), w.args[e]) + } + + var prefix string + if b.dialect == ORACLE { + prefix = ":p" + } else { + prefix = "@p" + } + + if sql, err = ConvertPlaceholder(sql, prefix); err != nil { + return "", nil, err + } + case POSTGRES: + if sql, err = ConvertPlaceholder(sql, "$"); err != nil { + return "", nil, err + } + } + + return sql, w.args, nil +} + +// ToBoundSQL generated a bound SQL string +func (b *Builder) ToBoundSQL() (string, error) { + w := NewWriter() + if err := b.WriteTo(w); err != nil { + return "", err + } + + return ConvertToBoundSQL(w.String(), w.args) +} diff --git a/builder/builder_b_test.go b/builder/builder_b_test.go new file mode 100644 index 0000000..6bab9a6 --- /dev/null +++ b/builder/builder_b_test.go @@ -0,0 +1,298 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" + "math/rand" + "testing" +) + +type randGenConf struct { + allowCond bool + allowJoin bool + allowLimit bool + allowUnion bool + allowHaving bool + allowGroupBy bool + allowOrderBy bool + allowSubQuery bool +} + +var expectedValues = []interface{}{ + "dangerous", "fun", "degree", "hospital", "horseshoe", "summit", "parallel", "height", "recommend", "invite", + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + +var queryFields = []string{"f1", "f2", "f2", "f4", "f5", "f6", "f7", "f8", "f9"} + +func BenchmarkSelect_Simple(b *testing.B) { + rgc := randGenConf{allowCond: true} + b.ResetTimer() + for i := 0; i < b.N; i++ { + randQuery("", &rgc).ToSQL() + } +} + +func BenchmarkSelect_SubQuery(b *testing.B) { + rgc := randGenConf{allowSubQuery: true, allowCond: true, allowGroupBy: true, allowHaving: true, allowOrderBy: true} + b.ResetTimer() + for i := 0; i < b.N; i++ { + randQuery("", &rgc).ToSQL() + } +} + +func BenchmarkSelect_SelectConditional4Oracle(b *testing.B) { + rgc := randGenConf{allowLimit: true, allowCond: true, allowGroupBy: true, allowHaving: true, allowOrderBy: true} + for i := 0; i < b.N; i++ { + randQuery(ORACLE, &rgc).ToSQL() + } +} + +func BenchmarkSelect_SelectConditional4Mssql(b *testing.B) { + rgc := randGenConf{allowLimit: true, allowCond: true, allowGroupBy: true, allowHaving: true, allowOrderBy: true} + b.ResetTimer() + for i := 0; i < b.N; i++ { + randQuery(MSSQL, &rgc).ToSQL() + } +} + +func BenchmarkSelect_SelectConditional4MysqlLike(b *testing.B) { + rgc := randGenConf{allowLimit: true, allowCond: true, allowGroupBy: true, allowHaving: true, allowOrderBy: true} + b.ResetTimer() + for i := 0; i < b.N; i++ { + randQuery(MYSQL, &rgc).ToSQL() + } +} + +func BenchmarkSelect_SelectConditional4Mixed(b *testing.B) { + rgc := randGenConf{allowLimit: true, allowCond: true, allowGroupBy: true, allowHaving: true, allowOrderBy: true} + b.ResetTimer() + for i := 0; i < b.N; i++ { + randQuery(randDialect(), &rgc).ToSQL() + } +} + +func BenchmarkSelect_SelectComplex4Oracle(b *testing.B) { + rgc := randGenConf{ + allowLimit: true, allowCond: true, + allowGroupBy: true, allowHaving: true, + allowOrderBy: true, allowSubQuery: true, + } + for i := 0; i < b.N; i++ { + randQuery(ORACLE, &rgc).ToSQL() + } +} + +func BenchmarkSelect_SelectComplex4Mssql(b *testing.B) { + rgc := randGenConf{ + allowLimit: true, allowCond: true, + allowGroupBy: true, allowHaving: true, + allowOrderBy: true, allowSubQuery: true, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + randQuery(MSSQL, &rgc).ToSQL() + } +} + +func BenchmarkSelect_SelectComplex4MysqlLike(b *testing.B) { + rgc := randGenConf{ + allowLimit: true, allowCond: true, + allowGroupBy: true, allowHaving: true, + allowOrderBy: true, allowSubQuery: true, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + randQuery(MYSQL, &rgc).ToSQL() + } +} + +func BenchmarkSelect_SelectComplex4MysqlMixed(b *testing.B) { + rgc := randGenConf{ + allowLimit: true, allowCond: true, + allowGroupBy: true, allowHaving: true, + allowOrderBy: true, allowSubQuery: true, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + randQuery(randDialect(), &rgc).ToSQL() + } +} + +func BenchmarkInsert(b *testing.B) { + rgc := randGenConf{allowCond: true} + b.ResetTimer() + for i := 0; i < b.N; i++ { + randInsertByCondition(&rgc).ToSQL() + } +} + +func BenchmarkUpdate(b *testing.B) { + rgc := randGenConf{allowCond: true} + b.ResetTimer() + for i := 0; i < b.N; i++ { + randUpdateByCondition(&rgc).ToSQL() + } +} + +// randQuery Generate a basic query for benchmark test. But be careful it's not a executable SQL in real db. +func randQuery(dialect string, rgc *randGenConf) *Builder { + b := randSelectByCondition(dialect, rgc) + isUnionized := rgc.allowUnion && rand.Intn(1000) >= 500 + if isUnionized { + r := rand.Intn(3) + 1 + for i := r; i < r; i++ { + b = b.Union("all", randSelectByCondition(dialect, rgc)) + } + } + + if isUnionized && rgc.allowLimit && rand.Intn(1000) >= 500 { + b = randLimit(Dialect(dialect).Select().From(b, "t")) + } + + return b +} + +func randInsertByCondition(rgc *randGenConf) *Builder { + fields := randSelects() + + times := rand.Intn(10) + 1 + + eqs := Eq{} + for i := 0; i < times; i++ { + eqs[fields[rand.Intn(len(fields))]] = "expected" + } + + b := Insert(eqs).From("table1") + + if rgc.allowCond && rand.Intn(1000) >= 500 { + b = b.Where(randCond(b.selects, 3)) + } + + return b +} + +func randUpdateByCondition(rgc *randGenConf) *Builder { + fields := randSelects() + + times := rand.Intn(10) + 1 + + eqs := Eq{} + for i := 0; i < times; i++ { + eqs[fields[rand.Intn(len(fields))]] = randVal() + } + + b := Update(eqs).From("table1") + + if rgc.allowCond && rand.Intn(1000) >= 500 { + b.Where(randCond(fields, 3)) + } + + return b +} + +func randSelectByCondition(dialect string, rgc *randGenConf) *Builder { + var b *Builder + if rgc.allowSubQuery { + cpRgc := *rgc + cpRgc.allowSubQuery = false + b = Dialect(dialect).Select(randSelects()...).From(randQuery(dialect, &cpRgc), randTableName(0)) + } else { + b = Dialect(dialect).Select(randSelects()...).From(randTableName(0)) + } + if rgc.allowJoin { + b = randJoin(b, 3) + } + if rgc.allowCond && rand.Intn(1000) >= 500 { + b = b.Where(randCond(b.selects, 3)) + } + if rgc.allowLimit && rand.Intn(1000) >= 500 { + b = randLimit(b) + } + if rgc.allowOrderBy && rand.Intn(1000) >= 500 { + b = randOrderBy(b) + } + if rgc.allowHaving && rand.Intn(1000) >= 500 { + b = randHaving(b) + } + if rgc.allowGroupBy && rand.Intn(1000) >= 500 { + b = randGroupBy(b) + } + + return b +} + +func randDialect() string { + dialects := []string{MYSQL, ORACLE, MSSQL, SQLITE, POSTGRES} + + return dialects[rand.Intn(len(dialects))] +} + +func randSelects() []string { + if rand.Intn(1000) > 900 { + return []string{"*"} + } + + rdx := rand.Intn(len(queryFields) / 2) + return queryFields[rdx:] +} + +func randTableName(offset int) string { + return fmt.Sprintf("table%v", rand.Intn(10)+offset) +} + +func randJoin(b *Builder, lessThan int) *Builder { + if lessThan <= 0 { + return b + } + + times := rand.Intn(lessThan) + + for i := 0; i < times; i++ { + tableName := randTableName(i * 10) + b = b.Join("", tableName, fmt.Sprintf("%v.id = %v.id", b.TableName(), tableName)) + } + + return b +} + +func randCond(selects []string, lessThan int) Cond { + if len(selects) <= 0 { + return nil + } + + cond := NewCond() + + times := rand.Intn(lessThan) + for i := 0; i < times; i++ { + cond = cond.And(Eq{selects[rand.Intn(len(selects))]: randVal()}) + } + + return cond +} + +func randLimit(b *Builder) *Builder { + r := rand.Intn(1000) + 1 + if r > 500 { + return b.Limit(r, 1000) + } + return b.Limit(r) +} + +func randOrderBy(b *Builder) *Builder { + return b.OrderBy(fmt.Sprintf("%v ASC", b.selects[rand.Intn(len(b.selects))])) +} + +func randHaving(b *Builder) *Builder { + return b.OrderBy(fmt.Sprintf("%v = %v", b.selects[rand.Intn(len(b.selects))], randVal())) +} + +func randGroupBy(b *Builder) *Builder { + return b.GroupBy(fmt.Sprintf("%v = %v", b.selects[rand.Intn(len(b.selects))], randVal())) +} + +func randVal() interface{} { + return expectedValues[rand.Intn(len(expectedValues))] +} diff --git a/builder/builder_delete.go b/builder/builder_delete.go new file mode 100644 index 0000000..317cc3f --- /dev/null +++ b/builder/builder_delete.go @@ -0,0 +1,27 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" +) + +// Delete creates a delete Builder +func Delete(conds ...Cond) *Builder { + builder := &Builder{cond: NewCond()} + return builder.Delete(conds...) +} + +func (b *Builder) deleteWriteTo(w Writer) error { + if len(b.from) <= 0 { + return ErrNoTableName + } + + if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.from); err != nil { + return err + } + + return b.cond.WriteTo(w) +} diff --git a/builder/builder_delete_test.go b/builder/builder_delete_test.go new file mode 100644 index 0000000..7b40498 --- /dev/null +++ b/builder/builder_delete_test.go @@ -0,0 +1,24 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuilderDelete(t *testing.T) { + sql, args, err := Delete(Eq{"a": 1}).From("table1").ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "DELETE FROM table1 WHERE a=?", sql) + assert.EqualValues(t, []interface{}{1}, args) +} + +func TestDeleteNoTable(t *testing.T) { + _, _, err := Delete(Eq{"b": "0"}).ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrNoTableName, err) +} diff --git a/builder/builder_insert.go b/builder/builder_insert.go new file mode 100644 index 0000000..8cef5c5 --- /dev/null +++ b/builder/builder_insert.go @@ -0,0 +1,149 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "bytes" + "fmt" + "sort" +) + +// Insert creates an insert Builder +func Insert(eq ...interface{}) *Builder { + builder := &Builder{cond: NewCond()} + return builder.Insert(eq...) +} + +func (b *Builder) insertSelectWriteTo(w Writer) error { + if _, err := fmt.Fprintf(w, "INSERT INTO %s ", b.into); err != nil { + return err + } + + if len(b.insertCols) > 0 { + fmt.Fprintf(w, "(") + for _, col := range b.insertCols { + fmt.Fprintf(w, col) + } + fmt.Fprintf(w, ") ") + } + + return b.selectWriteTo(w) +} + +func (b *Builder) insertWriteTo(w Writer) error { + if len(b.into) <= 0 { + return ErrNoTableName + } + if len(b.insertCols) <= 0 && b.from == "" { + return ErrNoColumnToInsert + } + + if b.into != "" && b.from != "" { + return b.insertSelectWriteTo(w) + } + + if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.into); err != nil { + return err + } + + var args = make([]interface{}, 0) + var bs []byte + var valBuffer = bytes.NewBuffer(bs) + + for i, col := range b.insertCols { + value := b.insertVals[i] + fmt.Fprint(w, col) + if e, ok := value.(expr); ok { + fmt.Fprintf(valBuffer, "(%s)", e.sql) + args = append(args, e.args...) + } else if value == nil { + fmt.Fprintf(valBuffer, `null`) + } else { + fmt.Fprint(valBuffer, "?") + args = append(args, value) + } + + if i != len(b.insertCols)-1 { + if _, err := fmt.Fprint(w, ","); err != nil { + return err + } + if _, err := fmt.Fprint(valBuffer, ","); err != nil { + return err + } + } + } + + if _, err := fmt.Fprint(w, ") Values ("); err != nil { + return err + } + + if _, err := w.Write(valBuffer.Bytes()); err != nil { + return err + } + if _, err := fmt.Fprint(w, ")"); err != nil { + return err + } + + w.Append(args...) + + return nil +} + +type insertColsSorter struct { + cols []string + vals []interface{} +} + +func (s insertColsSorter) Len() int { + return len(s.cols) +} + +func (s insertColsSorter) Swap(i, j int) { + s.cols[i], s.cols[j] = s.cols[j], s.cols[i] + s.vals[i], s.vals[j] = s.vals[j], s.vals[i] +} + +func (s insertColsSorter) Less(i, j int) bool { + return s.cols[i] < s.cols[j] +} + +// Insert sets insert SQL +func (b *Builder) Insert(eq ...interface{}) *Builder { + if len(eq) > 0 { + var paramType = -1 + for _, e := range eq { + switch t := e.(type) { + case Eq: + if paramType == -1 { + paramType = 0 + } + if paramType != 0 { + break + } + for k, v := range t { + b.insertCols = append(b.insertCols, k) + b.insertVals = append(b.insertVals, v) + } + case string: + if paramType == -1 { + paramType = 1 + } + if paramType != 1 { + break + } + b.insertCols = append(b.insertCols, t) + } + } + } + + if len(b.insertCols) == len(b.insertVals) { + sort.Sort(insertColsSorter{ + cols: b.insertCols, + vals: b.insertVals, + }) + } + b.optype = insertType + return b +} diff --git a/builder/builder_insert_test.go b/builder/builder_insert_test.go new file mode 100644 index 0000000..bb7305b --- /dev/null +++ b/builder/builder_insert_test.go @@ -0,0 +1,54 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuilderInsert(t *testing.T) { + sql, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToBoundSQL() + assert.NoError(t, err) + assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,2)", sql) + + sql, err = Insert(Eq{"e": 3}, Eq{"c": 1}, Eq{"d": 2}).Into("table1").ToBoundSQL() + assert.NoError(t, err) + assert.EqualValues(t, "INSERT INTO table1 (c,d,e) Values (1,2,3)", sql) + + sql, err = Insert(Eq{"c": 1, "d": Expr("SELECT b FROM t WHERE d=? LIMIT 1", 2)}).Into("table1").ToBoundSQL() + assert.NoError(t, err) + assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,(SELECT b FROM t WHERE d=2 LIMIT 1))", sql) + + sql, err = Insert(Eq{"c": 1, "d": 2}).ToBoundSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrNoTableName, err) + assert.EqualValues(t, "", sql) + + sql, err = Insert(Eq{}).Into("table1").ToBoundSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrNoColumnToInsert, err) + assert.EqualValues(t, "", sql) + + sql, err = Insert(Eq{`a`: nil}).Into(`table1`).ToBoundSQL() + assert.NoError(t, err) + assert.EqualValues(t, `INSERT INTO table1 (a) Values (null)`, sql) + + sql, args, err := Insert(Eq{`a`: nil, `b`: `str`}).Into(`table1`).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, `INSERT INTO table1 (a,b) Values (null,?)`, sql) + assert.EqualValues(t, []interface{}{`str`}, args) +} + +func TestBuidlerInsert_Select(t *testing.T) { + sql, err := Insert().Into("table1").Select().From("table2").ToBoundSQL() + assert.NoError(t, err) + assert.EqualValues(t, "INSERT INTO table1 SELECT * FROM table2", sql) + + sql, err = Insert("a, b").Into("table1").Select("b, c").From("table2").ToBoundSQL() + assert.NoError(t, err) + assert.EqualValues(t, "INSERT INTO table1 (a, b) SELECT b, c FROM table2", sql) +} diff --git a/builder/builder_join.go b/builder/builder_join.go new file mode 100644 index 0000000..6b89edc --- /dev/null +++ b/builder/builder_join.go @@ -0,0 +1,56 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +// InnerJoin sets inner join +func (b *Builder) InnerJoin(joinTable, joinCond interface{}) *Builder { + return b.Join("INNER", joinTable, joinCond) +} + +// LeftJoin sets left join SQL +func (b *Builder) LeftJoin(joinTable, joinCond interface{}) *Builder { + return b.Join("LEFT", joinTable, joinCond) +} + +// RightJoin sets right join SQL +func (b *Builder) RightJoin(joinTable, joinCond interface{}) *Builder { + return b.Join("RIGHT", joinTable, joinCond) +} + +// CrossJoin sets cross join SQL +func (b *Builder) CrossJoin(joinTable, joinCond interface{}) *Builder { + return b.Join("CROSS", joinTable, joinCond) +} + +// FullJoin sets full join SQL +func (b *Builder) FullJoin(joinTable, joinCond interface{}) *Builder { + return b.Join("FULL", joinTable, joinCond) +} + +// Join sets join table and conditions +func (b *Builder) Join(joinType string, joinTable, joinCond interface{}, alias ...string) *Builder { + aliasName := "" + if len(alias) > 0 { + aliasName = alias[0] + } + switch joinCond.(type) { + case Cond: + b.joins = append(b.joins, join{ + joinType: joinType, + joinTable: joinTable, + joinCond: joinCond.(Cond), + joinAlias: aliasName, + }) + case string: + b.joins = append(b.joins, join{ + joinType: joinType, + joinTable: joinTable, + joinCond: Expr(joinCond.(string)), + joinAlias: aliasName, + }) + } + + return b +} diff --git a/builder/builder_join_test.go b/builder/builder_join_test.go new file mode 100644 index 0000000..b61c484 --- /dev/null +++ b/builder/builder_join_test.go @@ -0,0 +1,50 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestJoin(t *testing.T) { + sql, args, err := Select("c, d").From("table1").LeftJoin("table2", Eq{"table1.id": 1}.And(Lt{"table2.id": 3})). + RightJoin("table3", "table2.id = table3.tid").Where(Eq{"a": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT c, d FROM table1 LEFT JOIN table2 ON table1.id=? AND table2.id?) ON table1.id=? AND table2.id?) ON table2.id = table3.tid WHERE a=?", + sql) + assert.EqualValues(t, []interface{}{1, 1, 3, "2", 1}, args) +} diff --git a/builder/builder_limit.go b/builder/builder_limit.go new file mode 100644 index 0000000..82e1179 --- /dev/null +++ b/builder/builder_limit.go @@ -0,0 +1,103 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" + "strings" +) + +func (b *Builder) limitWriteTo(w Writer) error { + if strings.TrimSpace(b.dialect) == "" { + return ErrDialectNotSetUp + } + + if b.limitation != nil { + limit := b.limitation + if limit.offset < 0 || limit.limitN <= 0 { + return ErrInvalidLimitation + } + // erase limit condition + b.limitation = nil + defer func() { + b.limitation = limit + }() + ow := w.(*BytesWriter) + + switch strings.ToLower(strings.TrimSpace(b.dialect)) { + case ORACLE: + if len(b.selects) == 0 { + b.selects = append(b.selects, "*") + } + + var final *Builder + selects := b.selects + b.selects = append(selects, "ROWNUM RN") + + var wb *Builder + if b.optype == setOpType { + wb = Dialect(b.dialect).Select("at.*", "ROWNUM RN"). + From(b, "at") + } else { + wb = b + } + + if limit.offset == 0 { + final = Dialect(b.dialect).Select(selects...).From(wb, "at"). + Where(Lte{"at.RN": limit.limitN}) + } else { + sub := Dialect(b.dialect).Select("*"). + From(b, "at").Where(Lte{"at.RN": limit.offset + limit.limitN}) + + final = Dialect(b.dialect).Select(selects...).From(sub, "att"). + Where(Gt{"att.RN": limit.offset}) + } + + return final.WriteTo(ow) + case SQLITE, MYSQL, POSTGRES: + // if type UNION, we need to write previous content back to current writer + if b.optype == setOpType { + if err := b.WriteTo(ow); err != nil { + return err + } + } + + if limit.offset == 0 { + fmt.Fprint(ow, " LIMIT ", limit.limitN) + } else { + fmt.Fprintf(ow, " LIMIT %v OFFSET %v", limit.limitN, limit.offset) + } + case MSSQL: + if len(b.selects) == 0 { + b.selects = append(b.selects, "*") + } + + var final *Builder + selects := b.selects + b.selects = append(append([]string{fmt.Sprintf("TOP %d %v", limit.limitN+limit.offset, b.selects[0])}, + b.selects[1:]...), "ROW_NUMBER() OVER (ORDER BY (SELECT 1)) AS RN") + + var wb *Builder + if b.optype == setOpType { + wb = Dialect(b.dialect).Select("*", "ROW_NUMBER() OVER (ORDER BY (SELECT 1)) AS RN"). + From(b, "at") + } else { + wb = b + } + + if limit.offset == 0 { + final = Dialect(b.dialect).Select(selects...).From(wb, "at") + } else { + final = Dialect(b.dialect).Select(selects...).From(wb, "at").Where(Gt{"at.RN": limit.offset}) + } + + return final.WriteTo(ow) + default: + return ErrNotSupportType + } + } + + return nil +} diff --git a/builder/builder_select.go b/builder/builder_select.go new file mode 100644 index 0000000..0bb6813 --- /dev/null +++ b/builder/builder_select.go @@ -0,0 +1,171 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" +) + +// Select creates a select Builder +func Select(cols ...string) *Builder { + builder := &Builder{cond: NewCond()} + return builder.Select(cols...) +} + +func (b *Builder) selectWriteTo(w Writer) error { + if len(b.from) <= 0 && !b.isNested { + return ErrNoTableName + } + + // perform limit before writing to writer when b.dialect between ORACLE and MSSQL + // this avoid a duplicate writing problem in simple limit query + if b.limitation != nil && (b.dialect == ORACLE || b.dialect == MSSQL) { + return b.limitWriteTo(w) + } + + if _, err := fmt.Fprint(w, "SELECT "); err != nil { + return err + } + if len(b.selects) > 0 { + for i, s := range b.selects { + if _, err := fmt.Fprint(w, s); err != nil { + return err + } + if i != len(b.selects)-1 { + if _, err := fmt.Fprint(w, ","); err != nil { + return err + } + } + } + } else { + if _, err := fmt.Fprint(w, "*"); err != nil { + return err + } + } + + if b.subQuery == nil { + if _, err := fmt.Fprint(w, " FROM ", b.from); err != nil { + return err + } + } else { + if b.cond.IsValid() && len(b.from) <= 0 { + return ErrUnnamedDerivedTable + } + if b.subQuery.dialect != "" && b.dialect != b.subQuery.dialect { + return ErrInconsistentDialect + } + + // dialect of sub-query will inherit from the main one (if not set up) + if b.dialect != "" && b.subQuery.dialect == "" { + b.subQuery.dialect = b.dialect + } + + switch b.subQuery.optype { + case selectType, setOpType: + fmt.Fprint(w, " FROM (") + if err := b.subQuery.WriteTo(w); err != nil { + return err + } + + if len(b.from) == 0 { + fmt.Fprintf(w, ")") + } else { + fmt.Fprintf(w, ") %v", b.from) + } + default: + return ErrUnexpectedSubQuery + } + } + + for _, v := range b.joins { + b, ok := v.joinTable.(*Builder) + if ok { + if _, err := fmt.Fprintf(w, " %s JOIN (", v.joinType); err != nil { + return err + } + if err := b.WriteTo(w); err != nil { + return err + } + if v.joinAlias == "" { + if _, err := fmt.Fprintf(w, ") ON "); err != nil { + return err + } + } else { + if _, err := fmt.Fprintf(w, ") AS %s ON ", v.joinAlias); err != nil { + return err + } + } + } else { + + if v.joinAlias == "" { + if _, err := fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable); err != nil { + return err + } + } else { + if _, err := fmt.Fprintf(w, " %s JOIN %s AS %s ON ", v.joinType, v.joinTable, v.joinAlias); err != nil { + return err + } + } + } + + if err := v.joinCond.WriteTo(w); err != nil { + return err + } + } + + if b.cond.IsValid() { + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + + if err := b.cond.WriteTo(w); err != nil { + return err + } + } + + if len(b.groupBy) > 0 { + if _, err := fmt.Fprint(w, " GROUP BY ", b.groupBy); err != nil { + return err + } + } + + if len(b.having) > 0 { + if _, err := fmt.Fprint(w, " HAVING ", b.having); err != nil { + return err + } + } + + if len(b.orderBy) > 0 { + if _, err := fmt.Fprint(w, " ORDER BY ", b.orderBy); err != nil { + return err + } + } + + if b.limitation != nil { + if err := b.limitWriteTo(w); err != nil { + return err + } + } + + return nil +} + +// OrderBy orderBy SQL +func (b *Builder) OrderBy(orderBy string) *Builder { + b.orderBy = orderBy + return b +} + +// GroupBy groupby SQL +func (b *Builder) GroupBy(groupby string) *Builder { + b.groupBy = groupby + return b +} + +// Having having SQL +func (b *Builder) Having(having string) *Builder { + b.having = having + return b +} diff --git a/builder/builder_select_test.go b/builder/builder_select_test.go new file mode 100644 index 0000000..fc100e6 --- /dev/null +++ b/builder/builder_select_test.go @@ -0,0 +1,115 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuilder_Select(t *testing.T) { + sql, args, err := Select("c, d").From("table1").ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT c, d FROM table1", sql) + assert.EqualValues(t, []interface{}(nil), args) + + sql, args, err = Select("c, d").From("table1").Where(Eq{"a": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT c, d FROM table1 WHERE a=?", sql) + assert.EqualValues(t, []interface{}{1}, args) + + _, _, err = Select("c, d").ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrNoTableName, err) +} + +func TestBuilderSelectGroupBy(t *testing.T) { + sql, args, err := Select("c").From("table1").GroupBy("c").Having("count(c)=1").ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT c FROM table1 GROUP BY c HAVING count(c)=1", sql) + assert.EqualValues(t, 0, len(args)) + fmt.Println(sql, args) +} + +func TestBuilderSelectOrderBy(t *testing.T) { + sql, args, err := Select("c").From("table1").OrderBy("c DESC").ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT c FROM table1 ORDER BY c DESC", sql) + assert.EqualValues(t, 0, len(args)) + fmt.Println(sql, args) +} + +func TestBuilder_From(t *testing.T) { + // simple one + sql, args, err := Select("c").From("table1").ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT c FROM table1", sql) + assert.EqualValues(t, 0, len(args)) + + // from sub with alias + sql, args, err = Select("sub.id").From(Select("id").From("table1").Where(Eq{"a": 1}), + "sub").Where(Eq{"b": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT sub.id FROM (SELECT id FROM table1 WHERE a=?) sub WHERE b=?", sql) + assert.EqualValues(t, []interface{}{1, 1}, args) + + // from sub without alias and with conditions + sql, args, err = Select("sub.id").From(Select("id").From("table1").Where(Eq{"a": 1})).Where(Eq{"b": 1}).ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrUnnamedDerivedTable, err) + + // from sub without alias and conditions + sql, args, err = Select("sub.id").From(Select("id").From("table1").Where(Eq{"a": 1})).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT sub.id FROM (SELECT id FROM table1 WHERE a=?)", sql) + assert.EqualValues(t, []interface{}{1}, args) + + // from union with alias + sql, args, err = Select("sub.id").From( + Select("id").From("table1").Where(Eq{"a": 1}).Union( + "all", Select("id").From("table1").Where(Eq{"a": 2})), "sub").Where(Eq{"b": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT sub.id FROM ((SELECT id FROM table1 WHERE a=?) UNION ALL (SELECT id FROM table1 WHERE a=?)) sub WHERE b=?", sql) + assert.EqualValues(t, []interface{}{1, 2, 1}, args) + + // from union without alias + _, _, err = Select("sub.id").From( + Select("id").From("table1").Where(Eq{"a": 1}).Union( + "all", Select("id").From("table1").Where(Eq{"a": 2}))).Where(Eq{"b": 1}).ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrUnnamedDerivedTable, err) + + // will raise error + _, _, err = Select("c").From(Insert(Eq{"a": 1}).From("table1"), "table1").ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrUnexpectedSubQuery, err) + + // will raise error + _, _, err = Select("c").From(Delete(Eq{"a": 1}).From("table1"), "table1").ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrUnexpectedSubQuery, err) + + // from a sub-query in different dialect + _, _, err = MySQL().Select("sub.id").From( + Oracle().Select("id").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrInconsistentDialect, err) + + // from a sub-query (dialect set up) + sql, args, err = MySQL().Select("sub.id").From( + MySQL().Select("id").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT sub.id FROM (SELECT id FROM table1 WHERE a=?) sub WHERE b=?", sql) + assert.EqualValues(t, []interface{}{1, 1}, args) + + // from a sub-query (dialect not set up) + sql, args, err = MySQL().Select("sub.id").From( + Select("id").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT sub.id FROM (SELECT id FROM table1 WHERE a=?) sub WHERE b=?", sql) + assert.EqualValues(t, []interface{}{1, 1}, args) +} diff --git a/builder/builder_set_operations.go b/builder/builder_set_operations.go new file mode 100644 index 0000000..b2b4a3d --- /dev/null +++ b/builder/builder_set_operations.go @@ -0,0 +1,51 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" + "strings" +) + +func (b *Builder) setOpWriteTo(w Writer) error { + if b.limitation != nil || b.cond.IsValid() || + b.orderBy != "" || b.having != "" || b.groupBy != "" { + return ErrNotUnexpectedUnionConditions + } + + for idx, o := range b.setOps { + current := o.builder + if current.optype != selectType { + return ErrUnsupportedUnionMembers + } + + if len(b.setOps) == 1 { + if err := current.selectWriteTo(w); err != nil { + return err + } + } else { + if b.dialect != "" && b.dialect != current.dialect { + return ErrInconsistentDialect + } + + if idx != 0 { + if o.distinctType == "" { + fmt.Fprint(w, fmt.Sprintf(" %s ", strings.ToUpper(o.opType))) + } else { + fmt.Fprint(w, fmt.Sprintf(" %s %s ", strings.ToUpper(o.opType), strings.ToUpper(o.distinctType))) + } + } + fmt.Fprint(w, "(") + + if err := current.selectWriteTo(w); err != nil { + return err + } + + fmt.Fprint(w, ")") + } + } + + return nil +} diff --git a/builder/builder_set_operations_test.go b/builder/builder_set_operations_test.go new file mode 100644 index 0000000..9ebfb6f --- /dev/null +++ b/builder/builder_set_operations_test.go @@ -0,0 +1,278 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuilder_Union(t *testing.T) { + sql, args, err := Select("*").From("t1").Where(Eq{"status": "1"}). + Union("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Union("distinct", Select("*").From("t2").Where(Eq{"status": "3"})). + Union("", Select("*").From("t2").Where(Eq{"status": "3"})). + ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) UNION ALL (SELECT * FROM t2 WHERE status=?) UNION DISTINCT (SELECT * FROM t2 WHERE status=?) UNION (SELECT * FROM t2 WHERE status=?)", sql) + assert.EqualValues(t, []interface{}{"1", "2", "3", "3"}, args) + + // sub-query will inherit dialect from the main one + sql, args, err = MySQL().Select("*").From("t1").Where(Eq{"status": "1"}). + Union("all", Select("*").From("t2").Where(Eq{"status": "2"}).Limit(10)). + Union("", Select("*").From("t2").Where(Eq{"status": "3"})). + ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) UNION ALL (SELECT * FROM t2 WHERE status=? LIMIT 10) UNION (SELECT * FROM t2 WHERE status=?)", sql) + assert.EqualValues(t, []interface{}{"1", "2", "3"}, args) + + // will raise error + _, _, err = MySQL().Select("*").From("t1").Where(Eq{"status": "1"}). + Union("all", Oracle().Select("*").From("t2").Where(Eq{"status": "2"}).Limit(10)). + ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrInconsistentDialect, err) + + // will raise error + _, _, err = Select("*").From("table1").Where(Eq{"a": "1"}). + Union("all", Select("*").From("table2").Where(Eq{"a": "2"})). + Where(Eq{"a": 2}).Limit(5, 10). + ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrNotUnexpectedUnionConditions, err) + + // will raise error + _, _, err = Delete(Eq{"a": 1}).From("t1"). + Union("all", Select("*").From("t2").Where(Eq{"status": "2"})).ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrUnsupportedUnionMembers, err) + + // will be overwrote by SELECT op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Union("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Select("*").From("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) + + // will be overwrote by DELETE op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Union("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Delete(Eq{"status": "1"}).From("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) + + // will be overwrote by INSERT op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Union("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Insert(Eq{"status": "1"}).Into("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) +} + +func TestBuilder_Intersect(t *testing.T) { + sql, args, err := Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Intersect("distinct", Select("*").From("t2").Where(Eq{"status": "3"})). + Intersect("", Select("*").From("t2").Where(Eq{"status": "3"})). + ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) INTERSECT ALL (SELECT * FROM t2 WHERE status=?) INTERSECT DISTINCT (SELECT * FROM t2 WHERE status=?) INTERSECT (SELECT * FROM t2 WHERE status=?)", sql) + assert.EqualValues(t, []interface{}{"1", "2", "3", "3"}, args) + + // sub-query will inherit dialect from the main one + sql, args, err = MySQL().Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"}).Limit(10)). + Intersect("", Select("*").From("t2").Where(Eq{"status": "3"})). + ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) INTERSECT ALL (SELECT * FROM t2 WHERE status=? LIMIT 10) INTERSECT (SELECT * FROM t2 WHERE status=?)", sql) + assert.EqualValues(t, []interface{}{"1", "2", "3"}, args) + + // will raise error + _, _, err = MySQL().Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Oracle().Select("*").From("t2").Where(Eq{"status": "2"}).Limit(10)). + ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrInconsistentDialect, err) + + // will raise error + _, _, err = Select("*").From("table1").Where(Eq{"a": "1"}). + Intersect("all", Select("*").From("table2").Where(Eq{"a": "2"})). + Where(Eq{"a": 2}).Limit(5, 10). + ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrNotUnexpectedUnionConditions, err) + + // will raise error + _, _, err = Delete(Eq{"a": 1}).From("t1"). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"})).ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrUnsupportedUnionMembers, err) + + // will be overwrote by SELECT op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Select("*").From("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) + + // will be overwrote by DELETE op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Delete(Eq{"status": "1"}).From("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) + + // will be overwrote by INSERT op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Insert(Eq{"status": "1"}).Into("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) +} + +func TestBuilder_Except(t *testing.T) { + sql, args, err := Select("*").From("t1").Where(Eq{"status": "1"}). + Except("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Except("distinct", Select("*").From("t2").Where(Eq{"status": "3"})). + Except("", Select("*").From("t2").Where(Eq{"status": "3"})). + ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) EXCEPT ALL (SELECT * FROM t2 WHERE status=?) EXCEPT DISTINCT (SELECT * FROM t2 WHERE status=?) EXCEPT (SELECT * FROM t2 WHERE status=?)", sql) + assert.EqualValues(t, []interface{}{"1", "2", "3", "3"}, args) + + // sub-query will inherit dialect from the main one + sql, args, err = MySQL().Select("*").From("t1").Where(Eq{"status": "1"}). + Except("all", Select("*").From("t2").Where(Eq{"status": "2"}).Limit(10)). + Except("", Select("*").From("t2").Where(Eq{"status": "3"})). + ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) EXCEPT ALL (SELECT * FROM t2 WHERE status=? LIMIT 10) EXCEPT (SELECT * FROM t2 WHERE status=?)", sql) + assert.EqualValues(t, []interface{}{"1", "2", "3"}, args) + + // will raise error + _, _, err = MySQL().Select("*").From("t1").Where(Eq{"status": "1"}). + Except("all", Oracle().Select("*").From("t2").Where(Eq{"status": "2"}).Limit(10)). + ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrInconsistentDialect, err) + + // will raise error + _, _, err = Select("*").From("table1").Where(Eq{"a": "1"}). + Except("all", Select("*").From("table2").Where(Eq{"a": "2"})). + Where(Eq{"a": 2}).Limit(5, 10). + ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrNotUnexpectedUnionConditions, err) + + // will raise error + _, _, err = Delete(Eq{"a": 1}).From("t1"). + Except("all", Select("*").From("t2").Where(Eq{"status": "2"})).ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrUnsupportedUnionMembers, err) + + // will be overwrote by SELECT op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Except("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Select("*").From("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) + + // will be overwrote by DELETE op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Except("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Delete(Eq{"status": "1"}).From("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) + + // will be overwrote by INSERT op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Except("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Insert(Eq{"status": "1"}).Into("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) +} + +func TestBuilder_SetOperations(t *testing.T) { + sql, args, err := Select("*").From("t1").Where(Eq{"status": "1"}). + Union("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Intersect("distinct", Select("*").From("t2").Where(Eq{"status": "3"})). + Except("", Select("*").From("t2").Where(Eq{"status": "3"})). + ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) UNION ALL (SELECT * FROM t2 WHERE status=?) INTERSECT DISTINCT (SELECT * FROM t2 WHERE status=?) EXCEPT (SELECT * FROM t2 WHERE status=?)", sql) + assert.EqualValues(t, []interface{}{"1", "2", "3", "3"}, args) + + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Union("distinct", Select("*").From("t2").Where(Eq{"status": "3"})). + Except("", Select("*").From("t2").Where(Eq{"status": "3"})). + ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) INTERSECT ALL (SELECT * FROM t2 WHERE status=?) UNION DISTINCT (SELECT * FROM t2 WHERE status=?) EXCEPT (SELECT * FROM t2 WHERE status=?)", sql) + assert.EqualValues(t, []interface{}{"1", "2", "3", "3"}, args) + + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Except("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Intersect("distinct", Select("*").From("t2").Where(Eq{"status": "3"})). + Union("", Select("*").From("t2").Where(Eq{"status": "3"})). + ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) EXCEPT ALL (SELECT * FROM t2 WHERE status=?) INTERSECT DISTINCT (SELECT * FROM t2 WHERE status=?) UNION (SELECT * FROM t2 WHERE status=?)", sql) + assert.EqualValues(t, []interface{}{"1", "2", "3", "3"}, args) + + // sub-query will inherit dialect from the main one + sql, args, err = MySQL().Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"}).Limit(10)). + Intersect("", Select("*").From("t2").Where(Eq{"status": "3"})). + ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "(SELECT * FROM t1 WHERE status=?) INTERSECT ALL (SELECT * FROM t2 WHERE status=? LIMIT 10) INTERSECT (SELECT * FROM t2 WHERE status=?)", sql) + assert.EqualValues(t, []interface{}{"1", "2", "3"}, args) + + // will raise error + _, _, err = MySQL().Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Oracle().Select("*").From("t2").Where(Eq{"status": "2"}).Limit(10)). + ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrInconsistentDialect, err) + + // will raise error + _, _, err = Select("*").From("table1").Where(Eq{"a": "1"}). + Intersect("all", Select("*").From("table2").Where(Eq{"a": "2"})). + Where(Eq{"a": 2}).Limit(5, 10). + ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrNotUnexpectedUnionConditions, err) + + // will raise error + _, _, err = Delete(Eq{"a": 1}).From("t1"). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"})).ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrUnsupportedUnionMembers, err) + + // will be overwrote by SELECT op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Select("*").From("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) + + // will be overwrote by DELETE op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Delete(Eq{"status": "1"}).From("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) + + // will be overwrote by INSERT op + sql, args, err = Select("*").From("t1").Where(Eq{"status": "1"}). + Intersect("all", Select("*").From("t2").Where(Eq{"status": "2"})). + Insert(Eq{"status": "1"}).Into("t2").ToSQL() + assert.NoError(t, err) + fmt.Println(sql, args) +} diff --git a/builder/builder_test.go b/builder/builder_test.go new file mode 100644 index 0000000..79c0636 --- /dev/null +++ b/builder/builder_test.go @@ -0,0 +1,657 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type MyInt int + +func TestBuilderCond(t *testing.T) { + var cases = []struct { + cond Cond + sql string + args []interface{} + }{ + { + Eq{"a": 1}.And(Like{"b", "c"}).Or(Eq{"a": 2}.And(Like{"b", "g"})), + "(a=? AND b LIKE ?) OR (a=? AND b LIKE ?)", + []interface{}{1, "%c%", 2, "%g%"}, + }, + { + Eq{"a": 1}.Or(Like{"b", "c"}).And(Eq{"a": 2}.Or(Like{"b", "g"})), + "(a=? OR b LIKE ?) AND (a=? OR b LIKE ?)", + []interface{}{1, "%c%", 2, "%g%"}, + }, + { + Eq{"d": []string{"e", "f"}}, + "d IN (?,?)", + []interface{}{"e", "f"}, + }, + { + Eq{"e": Select("id").From("f").Where(Eq{"g": 1})}, + "e=(SELECT id FROM f WHERE g=?)", + []interface{}{1}, + }, + { + Eq{"e": Expr("SELECT id FROM f WHERE g=?", 1)}, + "e=(SELECT id FROM f WHERE g=?)", + []interface{}{1}, + }, + { + Like{"a", "%1"}.And(Like{"b", "%2"}), + "a LIKE ? AND b LIKE ?", + []interface{}{"%1", "%2"}, + }, + { + Like{"a", "%1"}.Or(Like{"b", "%2"}), + "a LIKE ? OR b LIKE ?", + []interface{}{"%1", "%2"}, + }, + { + Neq{"d": "e"}.Or(Neq{"f": "g"}), + "d<>? OR f<>?", + []interface{}{"e", "g"}, + }, + { + Neq{"d": []string{"e", "f"}}, + "d NOT IN (?,?)", + []interface{}{"e", "f"}, + }, + { + Neq{"e": Select("id").From("f").Where(Eq{"g": 1})}, + "e<>(SELECT id FROM f WHERE g=?)", + []interface{}{1}, + }, + { + Neq{"e": Expr("SELECT id FROM f WHERE g=?", 1)}, + "e<>(SELECT id FROM f WHERE g=?)", + []interface{}{1}, + }, + { + Lt{"d": 3}, + "d?", + []interface{}{3}, + }, + { + Gt{"d": 3}.And(Gt{"e": 4}), + "d>? AND e>?", + []interface{}{3, 4}, + }, + { + Gt{"d": 3}.Or(Gt{"e": 4}), + "d>? OR e>?", + []interface{}{3, 4}, + }, + { + Gt{"e": Select("id").From("f").Where(Eq{"g": 1})}, + "e>(SELECT id FROM f WHERE g=?)", + []interface{}{1}, + }, + { + Gt{"e": Expr("SELECT id FROM f WHERE g=?", 1)}, + "e>(SELECT id FROM f WHERE g=?)", + []interface{}{1}, + }, + { + Gte{"d": 3}, + "d>=?", + []interface{}{3}, + }, + { + Gte{"d": 3}.And(Gte{"e": 4}), + "d>=? AND e>=?", + []interface{}{3, 4}, + }, + { + Gte{"d": 3}.Or(Gte{"e": 4}), + "d>=? OR e>=?", + []interface{}{3, 4}, + }, + { + Gte{"e": Select("id").From("f").Where(Eq{"g": 1})}, + "e>=(SELECT id FROM f WHERE g=?)", + []interface{}{1}, + }, + { + Gte{"e": Expr("SELECT id FROM f WHERE g=?", 1)}, + "e>=(SELECT id FROM f WHERE g=?)", + []interface{}{1}, + }, + { + Between{"d", 0, 2}, + "d BETWEEN ? AND ?", + []interface{}{0, 2}, + }, + { + Between{"d", 0, Expr("CAST('2003-01-01' AS DATE)")}, + "d BETWEEN ? AND CAST('2003-01-01' AS DATE)", + []interface{}{0}, + }, + { + Between{"d", Expr("CAST('2003-01-01' AS DATE)"), 2}, + "d BETWEEN CAST('2003-01-01' AS DATE) AND ?", + []interface{}{2}, + }, + { + Between{"d", Expr("CAST('2003-01-01' AS DATE)"), Expr("CAST('2003-01-01' AS DATE)")}, + "d BETWEEN CAST('2003-01-01' AS DATE) AND CAST('2003-01-01' AS DATE)", + []interface{}{}, + }, + { + Between{"d", 0, 2}.And(Between{"e", 3, 4}), + "d BETWEEN ? AND ? AND e BETWEEN ? AND ?", + []interface{}{0, 2, 3, 4}, + }, + { + Between{"d", 0, 2}.Or(Between{"e", 3, 4}), + "d BETWEEN ? AND ? OR e BETWEEN ? AND ?", + []interface{}{0, 2, 3, 4}, + }, + { + Expr("a < ?", 1), + "a < ?", + []interface{}{1}, + }, + { + Expr("a < ?", 1).And(Eq{"b": 2}), + "(a < ?) AND b=?", + []interface{}{1, 2}, + }, + { + Expr("a < ?", 1).Or(Neq{"b": 2}), + "(a < ?) OR b<>?", + []interface{}{1, 2}, + }, + { + IsNull{"d"}, + "d IS NULL", + []interface{}{}, + }, + { + IsNull{"d"}.And(IsNull{"e"}), + "d IS NULL AND e IS NULL", + []interface{}{}, + }, + { + IsNull{"d"}.Or(IsNull{"e"}), + "d IS NULL OR e IS NULL", + []interface{}{}, + }, + { + NotNull{"d"}, + "d IS NOT NULL", + []interface{}{}, + }, + { + NotNull{"d"}.And(NotNull{"e"}), + "d IS NOT NULL AND e IS NOT NULL", + []interface{}{}, + }, + { + NotNull{"d"}.Or(NotNull{"e"}), + "d IS NOT NULL OR e IS NOT NULL", + []interface{}{}, + }, + { + NotIn("a", 1, 2).And(NotIn("b", "c", "d")), + "a NOT IN (?,?) AND b NOT IN (?,?)", + []interface{}{1, 2, "c", "d"}, + }, + { + In("a", 1, 2).Or(In("b", "c", "d")), + "a IN (?,?) OR b IN (?,?)", + []interface{}{1, 2, "c", "d"}, + }, + { + In("a", []int{1, 2}).Or(In("b", []string{"c", "d"})), + "a IN (?,?) OR b IN (?,?)", + []interface{}{1, 2, "c", "d"}, + }, + { + In("a", Expr("select id from x where name > ?", "b")), + "a IN (select id from x where name > ?)", + []interface{}{"b"}, + }, + { + In("a", []MyInt{1, 2}).Or(In("b", []string{"c", "d"})), + "a IN (?,?) OR b IN (?,?)", + []interface{}{MyInt(1), MyInt(2), "c", "d"}, + }, + { + In("a", []int{}), + "0=1", + []interface{}{}, + }, + { + In("a", []int{1}), + "a IN (?)", + []interface{}{1}, + }, + { + In("a", []int8{}), + "0=1", + []interface{}{}, + }, + { + In("a", []int8{1}), + "a IN (?)", + []interface{}{1}, + }, + { + In("a", []int16{}), + "0=1", + []interface{}{}, + }, + { + In("a", []int16{1}), + "a IN (?)", + []interface{}{1}, + }, + { + In("a", []int32{}), + "0=1", + []interface{}{}, + }, + { + In("a", []int32{1}), + "a IN (?)", + []interface{}{1}, + }, + { + In("a", []int64{}), + "0=1", + []interface{}{}, + }, + { + In("a", []int64{1}), + "a IN (?)", + []interface{}{1}, + }, + { + In("a", []uint{}), + "0=1", + []interface{}{}, + }, + { + In("a", []uint{1}), + "a IN (?)", + []interface{}{1}, + }, + { + In("a", []uint8{}), + "0=1", + []interface{}{}, + }, + { + In("a", []uint8{1}), + "a IN (?)", + []interface{}{1}, + }, + { + In("a", []uint16{}), + "0=1", + []interface{}{}, + }, + { + In("a", []uint16{1}), + "a IN (?)", + []interface{}{1}, + }, + { + In("a", []uint32{}), + "0=1", + []interface{}{}, + }, + { + In("a", []uint32{1}), + "a IN (?)", + []interface{}{1}, + }, + { + In("a", []uint64{}), + "0=1", + []interface{}{}, + }, + { + In("a", []uint64{1}), + "a IN (?)", + []interface{}{1}, + }, + { + In("a", []string{}), + "0=1", + []interface{}{}, + }, + { + In("a", []interface{}{}), + "0=1", + []interface{}{}, + }, + { + In("a", []MyInt{}), + "0=1", + []interface{}{}, + }, + { + In("a", []interface{}{1, 2, 3}).And(Eq{"b": "c"}), + "a IN (?,?,?) AND b=?", + []interface{}{1, 2, 3, "c"}, + }, + { + In("a", Select("id").From("b").Where(Eq{"c": 1})), + "a IN (SELECT id FROM b WHERE c=?)", + []interface{}{1}, + }, + { + NotIn("a", Expr("select id from x where name > ?", "b")), + "a NOT IN (select id from x where name > ?)", + []interface{}{"b"}, + }, + { + NotIn("a", []int{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []int{1}), + "a NOT IN (?)", + []interface{}{1}, + }, + { + NotIn("a", []int8{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []int8{1}), + "a NOT IN (?)", + []interface{}{1}, + }, + { + NotIn("a", []int16{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []int16{1}), + "a NOT IN (?)", + []interface{}{1}, + }, + { + NotIn("a", []int32{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []int32{1}), + "a NOT IN (?)", + []interface{}{1}, + }, + { + NotIn("a", []int64{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []int64{1}), + "a NOT IN (?)", + []interface{}{1}, + }, + { + NotIn("a", []uint{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []uint{1}), + "a NOT IN (?)", + []interface{}{1}, + }, + { + NotIn("a", []uint8{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []uint8{1}), + "a NOT IN (?)", + []interface{}{1}, + }, + { + NotIn("a", []uint16{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []uint16{1}), + "a NOT IN (?)", + []interface{}{1}, + }, + { + NotIn("a", []uint32{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []uint32{1}), + "a NOT IN (?)", + []interface{}{1}, + }, + { + NotIn("a", []uint64{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []uint64{1}), + "a NOT IN (?)", + []interface{}{1}, + }, + { + NotIn("a", []interface{}{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []string{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []MyInt{}), + "0=0", + []interface{}{}, + }, + { + NotIn("a", []MyInt{1, 2}), + "a NOT IN (?,?)", + []interface{}{1, 2}, + }, + { + NotIn("a", []interface{}{1, 2, 3}).And(Eq{"b": "c"}), + "a NOT IN (?,?,?) AND b=?", + []interface{}{1, 2, 3, "c"}, + }, + { + NotIn("a", []interface{}{1, 2, 3}).Or(Eq{"b": "c"}), + "a NOT IN (?,?,?) OR b=?", + []interface{}{1, 2, 3, "c"}, + }, + { + NotIn("a", Select("id").From("b").Where(Eq{"c": 1})), + "a NOT IN (SELECT id FROM b WHERE c=?)", + []interface{}{1}, + }, + { + Or(Eq{"a": 1, "b": 2}, Eq{"c": 3, "d": 4}), + "(a=? AND b=?) OR (c=? AND d=?)", + []interface{}{1, 2, 3, 4}, + }, + { + Not{Eq{"a": 1, "b": 2}}, + "NOT (a=? AND b=?)", + []interface{}{1, 2}, + }, + { + Not{Neq{"a": 1, "b": 2}}, + "NOT (a<>? AND b<>?)", + []interface{}{1, 2}, + }, + { + Not{Eq{"a": 1}.And(Eq{"b": 2})}, + "NOT (a=? AND b=?)", + []interface{}{1, 2}, + }, + { + Not{Neq{"a": 1}.And(Neq{"b": 2})}, + "NOT (a<>? AND b<>?)", + []interface{}{1, 2}, + }, + { + Not{Eq{"a": 1}}.And(Neq{"b": 2}), + "NOT a=? AND b<>?", + []interface{}{1, 2}, + }, + { + Not{Eq{"a": 1}}.Or(Neq{"b": 2}), + "NOT a=? OR b<>?", + []interface{}{1, 2}, + }, + } + + for _, k := range cases { + sql, args, err := ToSQL(k.cond) + assert.NoError(t, err) + assert.EqualValues(t, k.sql, sql) + + for i := 0; i < 10; i++ { + sql2, _, err := ToSQL(k.cond) + assert.NoError(t, err) + assert.EqualValues(t, sql, sql2) + } + + assert.EqualValues(t, len(args), len(k.args)) + + if len(args) > 0 { + for i := 0; i < len(args); i++ { + assert.EqualValues(t, k.args[i], args[i]) + } + } + } +} + +func TestSubquery(t *testing.T) { + subb := Select("id").From("table_b").Where(Eq{"b": "a"}) + b := Select("a, b").From("table_a").Where( + Eq{ + "b_id": subb, + "id": 23, + }, + ) + sql, args, err := b.ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT a, b FROM table_a WHERE b_id=(SELECT id FROM table_b WHERE b=?) AND id=?", sql) + assert.EqualValues(t, []interface{}{"a", 23}, args) +} + +// https://github.com/go-xorm/xorm/issues/820 +func TestExprCond(t *testing.T) { + b := Select("id").From("table1").Where(expr{sql: "a=? OR b=?", args: []interface{}{1, 2}}).Where(Or(Eq{"c": 3}, Eq{"d": 4})) + sql, args, err := b.ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "table1", b.TableName()) + assert.EqualValues(t, "SELECT id FROM table1 WHERE (a=? OR b=?) AND (c=? OR d=?)", sql) + assert.EqualValues(t, []interface{}{1, 2, 3, 4}, args) +} + +func TestBuilder_ToBoundSQL(t *testing.T) { + newSQL, err := Select("id").From("table").Where(In("a", 1, 2)).ToBoundSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT id FROM table WHERE a IN (1,2)", newSQL) +} + +func TestBuilder_From2(t *testing.T) { + b := Select("id").From("table_b", "tb").Where(Eq{"b": "a"}) + sql, args, err := b.ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT id FROM table_b tb WHERE b=?", sql) + assert.EqualValues(t, []interface{}{"a"}, args) + + b = Select().From("table_b", "tb").Where(Eq{"b": "a"}) + sql, args, err = b.ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT * FROM table_b tb WHERE b=?", sql) + assert.EqualValues(t, []interface{}{"a"}, args) +} + +func TestBuilder_And(t *testing.T) { + b := Select("id").From("table_b", "tb").Where(Eq{"b": "a"}).And(Neq{"c": "d"}) + sql, args, err := b.ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT id FROM table_b tb WHERE b=? AND c<>?", sql) + assert.EqualValues(t, []interface{}{"a", "d"}, args) +} + +func TestBuilder_Or(t *testing.T) { + b := Select("id").From("table_b", "tb").Where(Eq{"b": "a"}).Or(Neq{"c": "d"}) + sql, args, err := b.ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "SELECT id FROM table_b tb WHERE b=? OR c<>?", sql) + assert.EqualValues(t, []interface{}{"a", "d"}, args) +} diff --git a/builder/builder_update.go b/builder/builder_update.go new file mode 100644 index 0000000..5fffbe3 --- /dev/null +++ b/builder/builder_update.go @@ -0,0 +1,57 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" +) + +// UpdateCond defines an interface that cond could be used with update +type UpdateCond interface { + IsValid() bool + OpWriteTo(op string, w Writer) error +} + +// Update creates an update Builder +func Update(updates ...Cond) *Builder { + builder := &Builder{cond: NewCond()} + return builder.Update(updates...) +} + +func (b *Builder) updateWriteTo(w Writer) error { + if len(b.from) <= 0 { + return ErrNoTableName + } + if len(b.updates) <= 0 { + return ErrNoColumnToUpdate + } + + if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.from); err != nil { + return err + } + + for i, s := range b.updates { + + if err := s.OpWriteTo(",", w); err != nil { + return err + } + + if i != len(b.updates)-1 { + if _, err := fmt.Fprint(w, ","); err != nil { + return err + } + } + } + + if !b.cond.IsValid() { + return nil + } + + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + + return b.cond.WriteTo(w) +} diff --git a/builder/builder_update_test.go b/builder/builder_update_test.go new file mode 100644 index 0000000..9f3fc2e --- /dev/null +++ b/builder/builder_update_test.go @@ -0,0 +1,62 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuilderUpdate(t *testing.T) { + sql, args, err := Update(Eq{"a": 2}).From("table1").Where(Eq{"a": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "UPDATE table1 SET a=? WHERE a=?", sql) + assert.EqualValues(t, []interface{}{2, 1}, args) + + sql, args, err = Update(Eq{"a": 2, "b": 1}).From("table1").Where(Eq{"a": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "UPDATE table1 SET a=?,b=? WHERE a=?", sql) + assert.EqualValues(t, []interface{}{2, 1, 1}, args) + + sql, args, err = Update(Eq{"a": 2}, Eq{"b": 1}).From("table1").Where(Eq{"a": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "UPDATE table1 SET a=?,b=? WHERE a=?", sql) + assert.EqualValues(t, []interface{}{2, 1, 1}, args) + + sql, args, err = Update(Eq{"a": 2, "b": Incr(1)}).From("table2").Where(Eq{"a": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "UPDATE table2 SET a=?,b=b+? WHERE a=?", sql) + assert.EqualValues(t, []interface{}{2, 1, 1}, args) + + sql, args, err = Update(Eq{"a": 2, "b": Incr(1), "c": Decr(1), "d": Expr("select count(*) from table2")}).From("table2").Where(Eq{"a": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "UPDATE table2 SET a=?,b=b+?,c=c-?,d=(select count(*) from table2) WHERE a=?", sql) + assert.EqualValues(t, []interface{}{2, 1, 1, 1}, args) + + sql, args, err = Update(Eq{"a": 2}).Where(Eq{"a": 1}).ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrNoTableName, err) + + sql, args, err = Update(Eq{}).From("table1").Where(Eq{"a": 1}).ToSQL() + assert.Error(t, err) + assert.EqualValues(t, ErrNoColumnToUpdate, err) + + var builder = Builder{cond: NewCond()} + sql, args, err = builder.Update(Eq{"a": 2, "b": 1}).From("table1").Where(Eq{"a": 1}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "UPDATE table1 SET a=?,b=? WHERE a=?", sql) + assert.EqualValues(t, []interface{}{2, 1, 1}, args) + + sql, args, err = Update(Eq{"a": 1}, Expr("c = c+1")).From("table1").Where(Eq{"b": 2}).ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "UPDATE table1 SET a=?,c = c+1 WHERE b=?", sql) + assert.EqualValues(t, []interface{}{1, 2}, args) + + sql, args, err = Update(Eq{"a": 2}).From("table1").ToSQL() + assert.NoError(t, err) + assert.EqualValues(t, "UPDATE table1 SET a=?", sql) + assert.EqualValues(t, []interface{}{2}, args) +} diff --git a/builder/cond.go b/builder/cond.go new file mode 100644 index 0000000..149f5d8 --- /dev/null +++ b/builder/cond.go @@ -0,0 +1,38 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +// Cond defines an interface +type Cond interface { + WriteTo(Writer) error + And(...Cond) Cond + Or(...Cond) Cond + IsValid() bool +} + +type condEmpty struct{} + +var _ Cond = condEmpty{} + +// NewCond creates an empty condition +func NewCond() Cond { + return condEmpty{} +} + +func (condEmpty) WriteTo(w Writer) error { + return nil +} + +func (condEmpty) And(conds ...Cond) Cond { + return And(conds...) +} + +func (condEmpty) Or(conds ...Cond) Cond { + return Or(conds...) +} + +func (condEmpty) IsValid() bool { + return false +} diff --git a/builder/cond_and.go b/builder/cond_and.go new file mode 100644 index 0000000..e30bd18 --- /dev/null +++ b/builder/cond_and.go @@ -0,0 +1,61 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import "fmt" + +type condAnd []Cond + +var _ Cond = condAnd{} + +// And generates AND conditions +func And(conds ...Cond) Cond { + var result = make(condAnd, 0, len(conds)) + for _, cond := range conds { + if cond == nil || !cond.IsValid() { + continue + } + result = append(result, cond) + } + return result +} + +func (and condAnd) WriteTo(w Writer) error { + for i, cond := range and { + _, isOr := cond.(condOr) + _, isExpr := cond.(expr) + wrap := isOr || isExpr + if wrap { + fmt.Fprint(w, "(") + } + + err := cond.WriteTo(w) + if err != nil { + return err + } + + if wrap { + fmt.Fprint(w, ")") + } + + if i != len(and)-1 { + fmt.Fprint(w, " AND ") + } + } + + return nil +} + +func (and condAnd) And(conds ...Cond) Cond { + return And(and, And(conds...)) +} + +func (and condAnd) Or(conds ...Cond) Cond { + return Or(and, Or(conds...)) +} + +func (and condAnd) IsValid() bool { + return len(and) > 0 +} diff --git a/builder/cond_between.go b/builder/cond_between.go new file mode 100644 index 0000000..10e0b83 --- /dev/null +++ b/builder/cond_between.go @@ -0,0 +1,65 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import "fmt" + +// Between implmentes between condition +type Between struct { + Col string + LessVal interface{} + MoreVal interface{} +} + +var _ Cond = Between{} + +// WriteTo write data to Writer +func (between Between) WriteTo(w Writer) error { + if _, err := fmt.Fprintf(w, "%s BETWEEN ", between.Col); err != nil { + return err + } + if lv, ok := between.LessVal.(expr); ok { + if err := lv.WriteTo(w); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(w, "?"); err != nil { + return err + } + w.Append(between.LessVal) + } + + if _, err := fmt.Fprint(w, " AND "); err != nil { + return err + } + + if mv, ok := between.MoreVal.(expr); ok { + if err := mv.WriteTo(w); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(w, "?"); err != nil { + return err + } + w.Append(between.MoreVal) + } + + return nil +} + +// And implments And with other conditions +func (between Between) And(conds ...Cond) Cond { + return And(between, And(conds...)) +} + +// Or implments Or with other conditions +func (between Between) Or(conds ...Cond) Cond { + return Or(between, Or(conds...)) +} + +// IsValid tests if the condition is valid +func (between Between) IsValid() bool { + return len(between.Col) > 0 +} diff --git a/builder/cond_compare.go b/builder/cond_compare.go new file mode 100644 index 0000000..1c29371 --- /dev/null +++ b/builder/cond_compare.go @@ -0,0 +1,160 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import "fmt" + +// WriteMap writes conditions' SQL to Writer, op could be =, <>, >, <, <=, >= and etc. +func WriteMap(w Writer, data map[string]interface{}, op string) error { + var args = make([]interface{}, 0, len(data)) + var i = 0 + keys := make([]string, 0, len(data)) + for k := range data { + keys = append(keys, k) + } + + for _, k := range keys { + v := data[k] + switch v.(type) { + case expr: + if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil { + return err + } + + if err := v.(expr).WriteTo(w); err != nil { + return err + } + + if _, err := fmt.Fprintf(w, ")"); err != nil { + return err + } + case *Builder: + if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil { + return err + } + + if err := v.(*Builder).WriteTo(w); err != nil { + return err + } + + if _, err := fmt.Fprintf(w, ")"); err != nil { + return err + } + default: + if _, err := fmt.Fprintf(w, "%s%s?", k, op); err != nil { + return err + } + args = append(args, v) + } + if i != len(data)-1 { + if _, err := fmt.Fprint(w, " AND "); err != nil { + return err + } + } + i = i + 1 + } + w.Append(args...) + return nil +} + +// Lt defines < condition +type Lt map[string]interface{} + +var _ Cond = Lt{} + +// WriteTo write SQL to Writer +func (lt Lt) WriteTo(w Writer) error { + return WriteMap(w, lt, "<") +} + +// And implements And with other conditions +func (lt Lt) And(conds ...Cond) Cond { + return condAnd{lt, And(conds...)} +} + +// Or implements Or with other conditions +func (lt Lt) Or(conds ...Cond) Cond { + return condOr{lt, Or(conds...)} +} + +// IsValid tests if this Eq is valid +func (lt Lt) IsValid() bool { + return len(lt) > 0 +} + +// Lte defines <= condition +type Lte map[string]interface{} + +var _ Cond = Lte{} + +// WriteTo write SQL to Writer +func (lte Lte) WriteTo(w Writer) error { + return WriteMap(w, lte, "<=") +} + +// And implements And with other conditions +func (lte Lte) And(conds ...Cond) Cond { + return And(lte, And(conds...)) +} + +// Or implements Or with other conditions +func (lte Lte) Or(conds ...Cond) Cond { + return Or(lte, Or(conds...)) +} + +// IsValid tests if this Eq is valid +func (lte Lte) IsValid() bool { + return len(lte) > 0 +} + +// Gt defines > condition +type Gt map[string]interface{} + +var _ Cond = Gt{} + +// WriteTo write SQL to Writer +func (gt Gt) WriteTo(w Writer) error { + return WriteMap(w, gt, ">") +} + +// And implements And with other conditions +func (gt Gt) And(conds ...Cond) Cond { + return And(gt, And(conds...)) +} + +// Or implements Or with other conditions +func (gt Gt) Or(conds ...Cond) Cond { + return Or(gt, Or(conds...)) +} + +// IsValid tests if this Eq is valid +func (gt Gt) IsValid() bool { + return len(gt) > 0 +} + +// Gte defines >= condition +type Gte map[string]interface{} + +var _ Cond = Gte{} + +// WriteTo write SQL to Writer +func (gte Gte) WriteTo(w Writer) error { + return WriteMap(w, gte, ">=") +} + +// And implements And with other conditions +func (gte Gte) And(conds ...Cond) Cond { + return And(gte, And(conds...)) +} + +// Or implements Or with other conditions +func (gte Gte) Or(conds ...Cond) Cond { + return Or(gte, Or(conds...)) +} + +// IsValid tests if this Eq is valid +func (gte Gte) IsValid() bool { + return len(gte) > 0 +} diff --git a/builder/cond_eq.go b/builder/cond_eq.go new file mode 100644 index 0000000..9976d18 --- /dev/null +++ b/builder/cond_eq.go @@ -0,0 +1,117 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" + "sort" +) + +// Incr implements a type used by Eq +type Incr int + +// Decr implements a type used by Eq +type Decr int + +// Eq defines equals conditions +type Eq map[string]interface{} + +var _ Cond = Eq{} + +// OpWriteTo writes conditions with special operator +func (eq Eq) OpWriteTo(op string, w Writer) error { + var i = 0 + for _, k := range eq.sortedKeys() { + v := eq[k] + switch v.(type) { + case []int, []int64, []string, []int32, []int16, []int8, []uint, []uint64, []uint32, []uint16, []interface{}: + if err := In(k, v).WriteTo(w); err != nil { + return err + } + case expr: + if _, err := fmt.Fprintf(w, "%s=(", k); err != nil { + return err + } + + if err := v.(expr).WriteTo(w); err != nil { + return err + } + + if _, err := fmt.Fprintf(w, ")"); err != nil { + return err + } + case *Builder: + if _, err := fmt.Fprintf(w, "%s=(", k); err != nil { + return err + } + + if err := v.(*Builder).WriteTo(w); err != nil { + return err + } + + if _, err := fmt.Fprintf(w, ")"); err != nil { + return err + } + case Incr: + if _, err := fmt.Fprintf(w, "%s=%s+?", k, k); err != nil { + return err + } + w.Append(int(v.(Incr))) + case Decr: + if _, err := fmt.Fprintf(w, "%s=%s-?", k, k); err != nil { + return err + } + w.Append(int(v.(Decr))) + case nil: + if _, err := fmt.Fprintf(w, "%s=null", k); err != nil { + return err + } + default: + if _, err := fmt.Fprintf(w, "%s=?", k); err != nil { + return err + } + w.Append(v) + } + if i != len(eq)-1 { + if _, err := fmt.Fprint(w, op); err != nil { + return err + } + } + i = i + 1 + } + return nil +} + +// WriteTo writes SQL to Writer +func (eq Eq) WriteTo(w Writer) error { + return eq.OpWriteTo(" AND ", w) +} + +// And implements And with other conditions +func (eq Eq) And(conds ...Cond) Cond { + return And(eq, And(conds...)) +} + +// Or implements Or with other conditions +func (eq Eq) Or(conds ...Cond) Cond { + return Or(eq, Or(conds...)) +} + +// IsValid tests if this Eq is valid +func (eq Eq) IsValid() bool { + return len(eq) > 0 +} + +// sortedKeys returns all keys of this Eq sorted with sort.Strings. +// It is used internally for consistent ordering when generating +// SQL, see https://gitea.com/xorm/builder/issues/10 +func (eq Eq) sortedKeys() []string { + keys := make([]string, 0, len(eq)) + for key := range eq { + keys = append(keys, key) + } + sort.Strings(keys) + return keys +} diff --git a/builder/cond_expr.go b/builder/cond_expr.go new file mode 100644 index 0000000..8288aa0 --- /dev/null +++ b/builder/cond_expr.go @@ -0,0 +1,43 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import "fmt" + +type expr struct { + sql string + args []interface{} +} + +var _ Cond = expr{} + +// Expr generate customerize SQL +func Expr(sql string, args ...interface{}) Cond { + return expr{sql, args} +} + +func (expr expr) OpWriteTo(op string, w Writer) error { + return expr.WriteTo(w) +} + +func (expr expr) WriteTo(w Writer) error { + if _, err := fmt.Fprint(w, expr.sql); err != nil { + return err + } + w.Append(expr.args...) + return nil +} + +func (expr expr) And(conds ...Cond) Cond { + return And(expr, And(conds...)) +} + +func (expr expr) Or(conds ...Cond) Cond { + return Or(expr, Or(conds...)) +} + +func (expr expr) IsValid() bool { + return len(expr.sql) > 0 +} diff --git a/builder/cond_if.go b/builder/cond_if.go new file mode 100644 index 0000000..af9eb32 --- /dev/null +++ b/builder/cond_if.go @@ -0,0 +1,49 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +type condIf struct { + condition bool + condTrue Cond + condFalse Cond +} + +var _ Cond = condIf{} + +// If returns Cond via condition +func If(condition bool, condTrue Cond, condFalse ...Cond) Cond { + var c = condIf{ + condition: condition, + condTrue: condTrue, + } + if len(condFalse) > 0 { + c.condFalse = condFalse[0] + } + return c +} + +func (condIf condIf) WriteTo(w Writer) error { + if condIf.condition { + return condIf.condTrue.WriteTo(w) + } else if condIf.condFalse != nil { + return condIf.condFalse.WriteTo(w) + } + return nil +} + +func (condIf condIf) And(conds ...Cond) Cond { + return And(condIf, And(conds...)) +} + +func (condIf condIf) Or(conds ...Cond) Cond { + return Or(condIf, Or(conds...)) +} + +func (condIf condIf) IsValid() bool { + if condIf.condition { + return condIf.condTrue != nil + } + return condIf.condFalse != nil +} diff --git a/builder/cond_if_test.go b/builder/cond_if_test.go new file mode 100644 index 0000000..0a268c4 --- /dev/null +++ b/builder/cond_if_test.go @@ -0,0 +1,38 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCond_If(t *testing.T) { + cond1 := If(1 > 0, Eq{"a": 1}, Eq{"b": 1}) + sql, err := ToBoundSQL(cond1) + assert.NoError(t, err) + assert.EqualValues(t, "a=1", sql) + + cond2 := If(1 < 0, Eq{"a": 1}, Eq{"b": 1}) + sql, err = ToBoundSQL(cond2) + assert.NoError(t, err) + assert.EqualValues(t, "b=1", sql) + + cond3 := If(1 > 0, cond2, Eq{"c": 1}) + sql, err = ToBoundSQL(cond3) + assert.NoError(t, err) + assert.EqualValues(t, "b=1", sql) + + cond4 := If(2 < 0, Eq{"d": "a"}) + sql, err = ToBoundSQL(cond4) + assert.NoError(t, err) + assert.EqualValues(t, "", sql) + + cond5 := And(cond1, cond2, cond3, cond4) + sql, err = ToBoundSQL(cond5) + assert.NoError(t, err) + assert.EqualValues(t, "a=1 AND b=1 AND b=1", sql) +} diff --git a/builder/cond_in.go b/builder/cond_in.go new file mode 100644 index 0000000..f6366d3 --- /dev/null +++ b/builder/cond_in.go @@ -0,0 +1,237 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" + "reflect" + "strings" +) + +type condIn struct { + col string + vals []interface{} +} + +var _ Cond = condIn{} + +// In generates IN condition +func In(col string, values ...interface{}) Cond { + return condIn{col, values} +} + +func (condIn condIn) handleBlank(w Writer) error { + _, err := fmt.Fprint(w, "0=1") + return err +} + +func (condIn condIn) WriteTo(w Writer) error { + if len(condIn.vals) <= 0 { + return condIn.handleBlank(w) + } + + switch condIn.vals[0].(type) { + case []int8: + vals := condIn.vals[0].([]int8) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []int16: + vals := condIn.vals[0].([]int16) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []int: + vals := condIn.vals[0].([]int) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []int32: + vals := condIn.vals[0].([]int32) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []int64: + vals := condIn.vals[0].([]int64) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []uint8: + vals := condIn.vals[0].([]uint8) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []uint16: + vals := condIn.vals[0].([]uint16) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []uint: + vals := condIn.vals[0].([]uint) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []uint32: + vals := condIn.vals[0].([]uint32) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []uint64: + vals := condIn.vals[0].([]uint64) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []string: + vals := condIn.vals[0].([]string) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []interface{}: + vals := condIn.vals[0].([]interface{}) + if len(vals) <= 0 { + return condIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + w.Append(vals...) + case expr: + val := condIn.vals[0].(expr) + if _, err := fmt.Fprintf(w, "%s IN (", condIn.col); err != nil { + return err + } + if err := val.WriteTo(w); err != nil { + return err + } + if _, err := fmt.Fprintf(w, ")"); err != nil { + return err + } + case *Builder: + bd := condIn.vals[0].(*Builder) + if _, err := fmt.Fprintf(w, "%s IN (", condIn.col); err != nil { + return err + } + if err := bd.WriteTo(w); err != nil { + return err + } + if _, err := fmt.Fprintf(w, ")"); err != nil { + return err + } + default: + v := reflect.ValueOf(condIn.vals[0]) + if v.Kind() == reflect.Slice { + l := v.Len() + if l == 0 { + return condIn.handleBlank(w) + } + + questionMark := strings.Repeat("?,", l) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + + for i := 0; i < l; i++ { + w.Append(v.Index(i).Interface()) + } + } else { + questionMark := strings.Repeat("?,", len(condIn.vals)) + if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + w.Append(condIn.vals...) + } + } + return nil +} + +func (condIn condIn) And(conds ...Cond) Cond { + return And(condIn, And(conds...)) +} + +func (condIn condIn) Or(conds ...Cond) Cond { + return Or(condIn, Or(conds...)) +} + +func (condIn condIn) IsValid() bool { + return len(condIn.col) > 0 && len(condIn.vals) > 0 +} diff --git a/builder/cond_like.go b/builder/cond_like.go new file mode 100644 index 0000000..e34202f --- /dev/null +++ b/builder/cond_like.go @@ -0,0 +1,41 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import "fmt" + +// Like defines like condition +type Like [2]string + +var _ Cond = Like{"", ""} + +// WriteTo write SQL to Writer +func (like Like) WriteTo(w Writer) error { + if _, err := fmt.Fprintf(w, "%s LIKE ?", like[0]); err != nil { + return err + } + // FIXME: if use other regular express, this will be failed. but for compatible, keep this + if like[1][0] == '%' || like[1][len(like[1])-1] == '%' { + w.Append(like[1]) + } else { + w.Append("%" + like[1] + "%") + } + return nil +} + +// And implements And with other conditions +func (like Like) And(conds ...Cond) Cond { + return And(like, And(conds...)) +} + +// Or implements Or with other conditions +func (like Like) Or(conds ...Cond) Cond { + return Or(like, Or(conds...)) +} + +// IsValid tests if this condition is valid +func (like Like) IsValid() bool { + return len(like[0]) > 0 && len(like[1]) > 0 +} diff --git a/builder/cond_neq.go b/builder/cond_neq.go new file mode 100644 index 0000000..687c59f --- /dev/null +++ b/builder/cond_neq.go @@ -0,0 +1,94 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" + "sort" +) + +// Neq defines not equal conditions +type Neq map[string]interface{} + +var _ Cond = Neq{} + +// WriteTo writes SQL to Writer +func (neq Neq) WriteTo(w Writer) error { + var args = make([]interface{}, 0, len(neq)) + var i = 0 + for _, k := range neq.sortedKeys() { + v := neq[k] + switch v.(type) { + case []int, []int64, []string, []int32, []int16, []int8: + if err := NotIn(k, v).WriteTo(w); err != nil { + return err + } + case expr: + if _, err := fmt.Fprintf(w, "%s<>(", k); err != nil { + return err + } + + if err := v.(expr).WriteTo(w); err != nil { + return err + } + + if _, err := fmt.Fprintf(w, ")"); err != nil { + return err + } + case *Builder: + if _, err := fmt.Fprintf(w, "%s<>(", k); err != nil { + return err + } + + if err := v.(*Builder).WriteTo(w); err != nil { + return err + } + + if _, err := fmt.Fprintf(w, ")"); err != nil { + return err + } + default: + if _, err := fmt.Fprintf(w, "%s<>?", k); err != nil { + return err + } + args = append(args, v) + } + if i != len(neq)-1 { + if _, err := fmt.Fprint(w, " AND "); err != nil { + return err + } + } + i = i + 1 + } + w.Append(args...) + return nil +} + +// And implements And with other conditions +func (neq Neq) And(conds ...Cond) Cond { + return And(neq, And(conds...)) +} + +// Or implements Or with other conditions +func (neq Neq) Or(conds ...Cond) Cond { + return Or(neq, Or(conds...)) +} + +// IsValid tests if this condition is valid +func (neq Neq) IsValid() bool { + return len(neq) > 0 +} + +// sortedKeys returns all keys of this Neq sorted with sort.Strings. +// It is used internally for consistent ordering when generating +// SQL, see https://gitea.com/xorm/builder/issues/10 +func (neq Neq) sortedKeys() []string { + keys := make([]string, 0, len(neq)) + for key := range neq { + keys = append(keys, key) + } + sort.Strings(keys) + return keys +} diff --git a/builder/cond_not.go b/builder/cond_not.go new file mode 100644 index 0000000..667dfe7 --- /dev/null +++ b/builder/cond_not.go @@ -0,0 +1,77 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import "fmt" + +// Not defines NOT condition +type Not [1]Cond + +var _ Cond = Not{} + +// WriteTo writes SQL to Writer +func (not Not) WriteTo(w Writer) error { + if _, err := fmt.Fprint(w, "NOT "); err != nil { + return err + } + switch not[0].(type) { + case condAnd, condOr: + if _, err := fmt.Fprint(w, "("); err != nil { + return err + } + case Eq: + if len(not[0].(Eq)) > 1 { + if _, err := fmt.Fprint(w, "("); err != nil { + return err + } + } + case Neq: + if len(not[0].(Neq)) > 1 { + if _, err := fmt.Fprint(w, "("); err != nil { + return err + } + } + } + + if err := not[0].WriteTo(w); err != nil { + return err + } + + switch not[0].(type) { + case condAnd, condOr: + if _, err := fmt.Fprint(w, ")"); err != nil { + return err + } + case Eq: + if len(not[0].(Eq)) > 1 { + if _, err := fmt.Fprint(w, ")"); err != nil { + return err + } + } + case Neq: + if len(not[0].(Neq)) > 1 { + if _, err := fmt.Fprint(w, ")"); err != nil { + return err + } + } + } + + return nil +} + +// And implements And with other conditions +func (not Not) And(conds ...Cond) Cond { + return And(not, And(conds...)) +} + +// Or implements Or with other conditions +func (not Not) Or(conds ...Cond) Cond { + return Or(not, Or(conds...)) +} + +// IsValid tests if this condition is valid +func (not Not) IsValid() bool { + return not[0] != nil && not[0].IsValid() +} diff --git a/builder/cond_notin.go b/builder/cond_notin.go new file mode 100644 index 0000000..dc3ac49 --- /dev/null +++ b/builder/cond_notin.go @@ -0,0 +1,234 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "fmt" + "reflect" + "strings" +) + +type condNotIn condIn + +var _ Cond = condNotIn{} + +// NotIn generate NOT IN condition +func NotIn(col string, values ...interface{}) Cond { + return condNotIn{col, values} +} + +func (condNotIn condNotIn) handleBlank(w Writer) error { + _, err := fmt.Fprint(w, "0=0") + return err +} + +func (condNotIn condNotIn) WriteTo(w Writer) error { + if len(condNotIn.vals) <= 0 { + return condNotIn.handleBlank(w) + } + + switch condNotIn.vals[0].(type) { + case []int8: + vals := condNotIn.vals[0].([]int8) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []int16: + vals := condNotIn.vals[0].([]int16) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []int: + vals := condNotIn.vals[0].([]int) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []int32: + vals := condNotIn.vals[0].([]int32) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []int64: + vals := condNotIn.vals[0].([]int64) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []uint8: + vals := condNotIn.vals[0].([]uint8) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []uint16: + vals := condNotIn.vals[0].([]uint16) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []uint: + vals := condNotIn.vals[0].([]uint) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []uint32: + vals := condNotIn.vals[0].([]uint32) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []uint64: + vals := condNotIn.vals[0].([]uint64) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []string: + vals := condNotIn.vals[0].([]string) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + for _, val := range vals { + w.Append(val) + } + case []interface{}: + vals := condNotIn.vals[0].([]interface{}) + if len(vals) <= 0 { + return condNotIn.handleBlank(w) + } + questionMark := strings.Repeat("?,", len(vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + w.Append(vals...) + case expr: + val := condNotIn.vals[0].(expr) + if _, err := fmt.Fprintf(w, "%s NOT IN (", condNotIn.col); err != nil { + return err + } + if err := val.WriteTo(w); err != nil { + return err + } + if _, err := fmt.Fprintf(w, ")"); err != nil { + return err + } + case *Builder: + val := condNotIn.vals[0].(*Builder) + if _, err := fmt.Fprintf(w, "%s NOT IN (", condNotIn.col); err != nil { + return err + } + if err := val.WriteTo(w); err != nil { + return err + } + if _, err := fmt.Fprintf(w, ")"); err != nil { + return err + } + default: + v := reflect.ValueOf(condNotIn.vals[0]) + if v.Kind() == reflect.Slice { + l := v.Len() + if l == 0 { + return condNotIn.handleBlank(w) + } + + questionMark := strings.Repeat("?,", l) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + + for i := 0; i < l; i++ { + w.Append(v.Index(i).Interface()) + } + } else { + questionMark := strings.Repeat("?,", len(condNotIn.vals)) + if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil { + return err + } + w.Append(condNotIn.vals...) + } + } + return nil +} + +func (condNotIn condNotIn) And(conds ...Cond) Cond { + return And(condNotIn, And(conds...)) +} + +func (condNotIn condNotIn) Or(conds ...Cond) Cond { + return Or(condNotIn, Or(conds...)) +} + +func (condNotIn condNotIn) IsValid() bool { + return len(condNotIn.col) > 0 && len(condNotIn.vals) > 0 +} diff --git a/builder/cond_null.go b/builder/cond_null.go new file mode 100644 index 0000000..bf2aaf8 --- /dev/null +++ b/builder/cond_null.go @@ -0,0 +1,59 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import "fmt" + +// IsNull defines IS NULL condition +type IsNull [1]string + +var _ Cond = IsNull{""} + +// WriteTo write SQL to Writer +func (isNull IsNull) WriteTo(w Writer) error { + _, err := fmt.Fprintf(w, "%s IS NULL", isNull[0]) + return err +} + +// And implements And with other conditions +func (isNull IsNull) And(conds ...Cond) Cond { + return And(isNull, And(conds...)) +} + +// Or implements Or with other conditions +func (isNull IsNull) Or(conds ...Cond) Cond { + return Or(isNull, Or(conds...)) +} + +// IsValid tests if this condition is valid +func (isNull IsNull) IsValid() bool { + return len(isNull[0]) > 0 +} + +// NotNull defines NOT NULL condition +type NotNull [1]string + +var _ Cond = NotNull{""} + +// WriteTo write SQL to Writer +func (notNull NotNull) WriteTo(w Writer) error { + _, err := fmt.Fprintf(w, "%s IS NOT NULL", notNull[0]) + return err +} + +// And implements And with other conditions +func (notNull NotNull) And(conds ...Cond) Cond { + return And(notNull, And(conds...)) +} + +// Or implements Or with other conditions +func (notNull NotNull) Or(conds ...Cond) Cond { + return Or(notNull, Or(conds...)) +} + +// IsValid tests if this condition is valid +func (notNull NotNull) IsValid() bool { + return len(notNull[0]) > 0 +} diff --git a/builder/cond_or.go b/builder/cond_or.go new file mode 100644 index 0000000..5244265 --- /dev/null +++ b/builder/cond_or.go @@ -0,0 +1,69 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import "fmt" + +type condOr []Cond + +var _ Cond = condOr{} + +// Or sets OR conditions +func Or(conds ...Cond) Cond { + var result = make(condOr, 0, len(conds)) + for _, cond := range conds { + if cond == nil || !cond.IsValid() { + continue + } + result = append(result, cond) + } + return result +} + +// WriteTo implments Cond +func (o condOr) WriteTo(w Writer) error { + for i, cond := range o { + var needQuote bool + switch cond.(type) { + case condAnd, expr: + needQuote = true + case Eq: + needQuote = (len(cond.(Eq)) > 1) + case Neq: + needQuote = (len(cond.(Neq)) > 1) + } + + if needQuote { + fmt.Fprint(w, "(") + } + + err := cond.WriteTo(w) + if err != nil { + return err + } + + if needQuote { + fmt.Fprint(w, ")") + } + + if i != len(o)-1 { + fmt.Fprint(w, " OR ") + } + } + + return nil +} + +func (o condOr) And(conds ...Cond) Cond { + return And(o, And(conds...)) +} + +func (o condOr) Or(conds ...Cond) Cond { + return Or(o, Or(conds...)) +} + +func (o condOr) IsValid() bool { + return len(o) > 0 +} diff --git a/builder/cond_test.go b/builder/cond_test.go new file mode 100644 index 0000000..51f37e7 --- /dev/null +++ b/builder/cond_test.go @@ -0,0 +1,11 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import "testing" + +func TestCond_NotIn(t *testing.T) { + +} diff --git a/builder/doc.go b/builder/doc.go new file mode 100644 index 0000000..6e7dd45 --- /dev/null +++ b/builder/doc.go @@ -0,0 +1,120 @@ +// Copyright 2016 The XORM Authors. All rights reserved. +// Use of this source code is governed by a BSD +// license that can be found in the LICENSE file. + +/* + +Package builder is a simple and powerful sql builder for Go. + +Make sure you have installed Go 1.1+ and then: + + go get xorm.io/builder + +WARNNING: Currently, only query conditions are supported. Below is the supported conditions. + +1. Eq is a redefine of a map, you can give one or more conditions to Eq + + import . "xorm.io/builder" + + sql, args, _ := ToSQL(Eq{"a":1}) + // a=? [1] + sql, args, _ := ToSQL(Eq{"b":"c"}.And(Eq{"c": 0})) + // b=? AND c=? ["c", 0] + sql, args, _ := ToSQL(Eq{"b":"c", "c":0}) + // b=? AND c=? ["c", 0] + sql, args, _ := ToSQL(Eq{"b":"c"}.Or(Eq{"b":"d"})) + // b=? OR b=? ["c", "d"] + sql, args, _ := ToSQL(Eq{"b": []string{"c", "d"}}) + // b IN (?,?) ["c", "d"] + sql, args, _ := ToSQL(Eq{"b": 1, "c":[]int{2, 3}}) + // b=? AND c IN (?,?) [1, 2, 3] + +2. Neq is the same to Eq + + import . "xorm.io/builder" + + sql, args, _ := ToSQL(Neq{"a":1}) + // a<>? [1] + sql, args, _ := ToSQL(Neq{"b":"c"}.And(Neq{"c": 0})) + // b<>? AND c<>? ["c", 0] + sql, args, _ := ToSQL(Neq{"b":"c", "c":0}) + // b<>? AND c<>? ["c", 0] + sql, args, _ := ToSQL(Neq{"b":"c"}.Or(Neq{"b":"d"})) + // b<>? OR b<>? ["c", "d"] + sql, args, _ := ToSQL(Neq{"b": []string{"c", "d"}}) + // b NOT IN (?,?) ["c", "d"] + sql, args, _ := ToSQL(Neq{"b": 1, "c":[]int{2, 3}}) + // b<>? AND c NOT IN (?,?) [1, 2, 3] + +3. Gt, Gte, Lt, Lte + + import . "xorm.io/builder" + + sql, args, _ := ToSQL(Gt{"a", 1}.And(Gte{"b", 2})) + // a>? AND b>=? [1, 2] + sql, args, _ := ToSQL(Lt{"a", 1}.Or(Lte{"b", 2})) + // a? [1, %c%, 2] + +9. Or(conds ...Cond), Or can connect one or more conditions via Or + + import . "xorm.io/builder" + + sql, args, _ := ToSQL(Or(Eq{"a":1}, Like{"b", "c"}, Neq{"d", 2})) + // a=? OR b LIKE ? OR d<>? [1, %c%, 2] + sql, args, _ := ToSQL(Or(Eq{"a":1}, And(Like{"b", "c"}, Neq{"d", 2}))) + // a=? OR (b LIKE ? AND d<>?) [1, %c%, 2] + +10. Between + + import . "xorm.io/builder" + + sql, args, _ := ToSQL(Between("a", 1, 2)) + // a BETWEEN 1 AND 2 + +11. define yourself conditions +Since Cond is a interface, you can define yourself conditions and compare with them +*/ +package builder diff --git a/builder/error.go b/builder/error.go new file mode 100644 index 0000000..b0ded29 --- /dev/null +++ b/builder/error.go @@ -0,0 +1,40 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import "errors" + +var ( + // ErrNotSupportType not supported SQL type error + ErrNotSupportType = errors.New("Not supported SQL type") + // ErrNoNotInConditions no NOT IN params error + ErrNoNotInConditions = errors.New("No NOT IN conditions") + // ErrNoInConditions no IN params error + ErrNoInConditions = errors.New("No IN conditions") + // ErrNeedMoreArguments need more arguments + ErrNeedMoreArguments = errors.New("Need more sql arguments") + // ErrNoTableName no table name + ErrNoTableName = errors.New("No table indicated") + // ErrNoColumnToUpdate no column to update + ErrNoColumnToUpdate = errors.New("No column(s) to update") + // ErrNoColumnToInsert no column to insert + ErrNoColumnToInsert = errors.New("No column(s) to insert") + // ErrNotSupportDialectType not supported dialect type error + ErrNotSupportDialectType = errors.New("Not supported dialect type") + // ErrNotUnexpectedUnionConditions using union in a wrong way + ErrNotUnexpectedUnionConditions = errors.New("Unexpected conditional fields in UNION query") + // ErrUnsupportedUnionMembers unexpected members in UNION query + ErrUnsupportedUnionMembers = errors.New("Unexpected members in UNION query") + // ErrUnexpectedSubQuery Unexpected sub-query in SELECT query + ErrUnexpectedSubQuery = errors.New("Unexpected sub-query in SELECT query") + // ErrDialectNotSetUp dialect is not setup yet + ErrDialectNotSetUp = errors.New("Dialect is not setup yet, try to use `Dialect(dbType)` at first") + // ErrInvalidLimitation offset or limit is not correct + ErrInvalidLimitation = errors.New("Offset or limit is not correct") + // ErrUnnamedDerivedTable Every derived table must have its own alias + ErrUnnamedDerivedTable = errors.New("Every derived table must have its own alias") + // ErrInconsistentDialect Inconsistent dialect in same builder + ErrInconsistentDialect = errors.New("Inconsistent dialect in same builder") +) diff --git a/builder/sql.go b/builder/sql.go new file mode 100644 index 0000000..479e6b6 --- /dev/null +++ b/builder/sql.go @@ -0,0 +1,171 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + sql2 "database/sql" + "fmt" + "reflect" + "strings" + "time" +) + +func condToSQL(cond Cond) (string, []interface{}, error) { + if cond == nil || !cond.IsValid() { + return "", nil, nil + } + + w := NewWriter() + if err := cond.WriteTo(w); err != nil { + return "", nil, err + } + return w.String(), w.args, nil +} + +func condToBoundSQL(cond Cond) (string, error) { + if cond == nil || !cond.IsValid() { + return "", nil + } + + w := NewWriter() + if err := cond.WriteTo(w); err != nil { + return "", err + } + return ConvertToBoundSQL(w.String(), w.args) +} + +// ToSQL convert a builder or conditions to SQL and args +func ToSQL(cond interface{}) (string, []interface{}, error) { + switch cond.(type) { + case Cond: + return condToSQL(cond.(Cond)) + case *Builder: + return cond.(*Builder).ToSQL() + } + return "", nil, ErrNotSupportType +} + +// ToBoundSQL convert a builder or conditions to parameters bound SQL +func ToBoundSQL(cond interface{}) (string, error) { + switch cond.(type) { + case Cond: + return condToBoundSQL(cond.(Cond)) + case *Builder: + return cond.(*Builder).ToBoundSQL() + } + return "", ErrNotSupportType +} + +func noSQLQuoteNeeded(a interface{}) bool { + if a == nil { + return false + } + + switch a.(type) { + case Field: + return true + case int, int8, int16, int32, int64: + return true + case uint, uint8, uint16, uint32, uint64: + return true + case float32, float64: + return true + case bool: + return true + case string: + return false + case time.Time, *time.Time: + return false + } + + t := reflect.TypeOf(a) + + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return true + case reflect.Float32, reflect.Float64: + return true + case reflect.Bool: + return true + case reflect.String: + return false + } + + return false +} + +// ConvertToBoundSQL will convert SQL and args to a bound SQL +func ConvertToBoundSQL(sql string, args []interface{}) (string, error) { + buf := strings.Builder{} + var i, j, start int + for ; i < len(sql); i++ { + if sql[i] == '?' { + _, err := buf.WriteString(sql[start:i]) + if err != nil { + return "", err + } + start = i + 1 + + if len(args) == j { + return "", ErrNeedMoreArguments + } + + arg := args[j] + if namedArg, ok := arg.(sql2.NamedArg); ok { + arg = namedArg.Value + } + + if noSQLQuoteNeeded(arg) { + _, err = fmt.Fprint(&buf, arg) + } else { + // replace ' -> '' (standard replacement) to avoid critical SQL injection, + // NOTICE: may allow some injection like % (or _) in LIKE query + _, err = fmt.Fprintf(&buf, "'%v'", strings.Replace(fmt.Sprintf("%v", arg), "'", + "''", -1)) + } + if err != nil { + return "", err + } + j = j + 1 + } + } + _, err := buf.WriteString(sql[start:]) + if err != nil { + return "", err + } + return buf.String(), nil +} + +// ConvertPlaceholder replaces the place holder ? to $1, $2 ... or :1, :2 ... according prefix +func ConvertPlaceholder(sql, prefix string) (string, error) { + buf := strings.Builder{} + var i, j, start int + var ready = true + for ; i < len(sql); i++ { + if sql[i] == '\'' && i > 0 && sql[i-1] != '\\' { + ready = !ready + } + if ready && sql[i] == '?' { + if _, err := buf.WriteString(sql[start:i]); err != nil { + return "", err + } + + start = i + 1 + j = j + 1 + + if _, err := buf.WriteString(fmt.Sprintf("%v%d", prefix, j)); err != nil { + return "", err + } + } + } + + if _, err := buf.WriteString(sql[start:]); err != nil { + return "", err + } + + return buf.String(), nil +} diff --git a/builder/writer.go b/builder/writer.go new file mode 100644 index 0000000..fb4fae5 --- /dev/null +++ b/builder/writer.go @@ -0,0 +1,42 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package builder + +import ( + "io" + "strings" +) + +// Writer defines the interface +type Writer interface { + io.Writer + Append(...interface{}) +} + +var _ Writer = NewWriter() + +// BytesWriter implments Writer and save SQL in bytes.Buffer +type BytesWriter struct { + *strings.Builder + args []interface{} +} + +// NewWriter creates a new string writer +func NewWriter() *BytesWriter { + w := &BytesWriter{ + Builder: &strings.Builder{}, + } + return w +} + +// Append appends args to Writer +func (w *BytesWriter) Append(args ...interface{}) { + w.args = append(w.args, args...) +} + +// Args returns args +func (w *BytesWriter) Args() []interface{} { + return w.args +} diff --git a/builder/z_vql_ext.go b/builder/z_vql_ext.go new file mode 100644 index 0000000..7aa20fd --- /dev/null +++ b/builder/z_vql_ext.go @@ -0,0 +1,121 @@ +package builder + +func CondType(cond Cond) string { + switch cond.(type) { + case condAnd: + return "and" + case condOr: + return "or" + case condIn: + return "in" + case condNotIn: + return "not in" + case condIf: + return "if" + } + return "" +} + +func GetCondAnd(cond Cond) condAnd { + switch cond.(type) { + case condAnd: + return cond.(condAnd) + } + return nil +} + +func GetCondOr(cond Cond) condOr { + switch cond.(type) { + case condOr: + return cond.(condOr) + } + return nil +} + +func GetCondEq(cond Cond) Eq { + switch cond.(type) { + case Eq: + return cond.(Eq) + } + return nil +} + +func GetCondNeq(cond Cond) Neq { + switch cond.(type) { + case Neq: + return cond.(Neq) + } + return nil +} + +func GetCondGt(cond Cond) Gt { + switch cond.(type) { + case Gt: + return cond.(Gt) + } + return nil +} + +func GetCondGte(cond Cond) Gte { + switch cond.(type) { + case Gte: + return cond.(Gte) + } + return nil +} + +func GetCondLt(cond Cond) Lt { + switch cond.(type) { + case Lt: + return cond.(Lt) + } + return nil +} + +func GetCondLte(cond Cond) Lte { + switch cond.(type) { + case Lte: + return cond.(Lte) + } + return nil +} + +func GetCondLike(cond Cond) (l Like) { + + switch cond.(type) { + case Like: + return cond.(Like) + } + return +} + +func GetCondIn(cond Cond) (i condIn) { + switch cond.(type) { + case condIn: + return cond.(condIn) + } + return +} + +func GetCondNotIn(cond Cond) (i condNotIn) { + switch cond.(type) { + case condNotIn: + return cond.(condNotIn) + } + return +} + +func GetCondIF(cond Cond) (i condIf) { + switch cond.(type) { + case condIf: + return cond.(condIf) + } + return +} + +func GetCondByBuilder(b Builder) (cond Cond) { + + return b.cond +} + +type Field string diff --git a/types.go b/types.go new file mode 100644 index 0000000..ad8c7d9 --- /dev/null +++ b/types.go @@ -0,0 +1,743 @@ +package vql + +import ( + "encoding/json" + "fmt" + "log" + "reflect" + "strings" + + "git.ouxuan.net/3136352472/vql/builder" + "github.com/spf13/cast" + "github.com/xwb1989/sqlparser" +) + +type CondType string + +const ( + CondTypeCondition CondType = "condition" + CondTypeAnd CondType = "and" + CondTypeOr CondType = "or" +) + +type FieldValue struct { + Type string `json:"type"` // field, value + Field *Field `json:"field,omitempty"` + Value string `json:"value,omitempty"` +} + +type WhereCondition struct { + Condition string `json:"condition"` + Left FieldValue `json:"left"` + Right FieldValue `json:"right"` +} + +type Where struct { + Cond CondType `json:"cond"` + Sub []Where `json:"sub,omitempty"` + Value *WhereCondition `json:"value,omitempty"` +} + +func (w *Where) ToCond() (builder.Cond, error) { + var cond builder.Cond + if w.Cond != CondTypeCondition { + var conds []builder.Cond + for _, v := range w.Sub { + subCond, err := v.ToCond() + if err != nil { + return nil, err + } + conds = append(conds, subCond) + } + + switch w.Cond { + case CondTypeAnd: + cond = builder.And(conds...) + case CondTypeOr: + cond = builder.Or(conds...) + } + } else { + + l := "" + if w.Value.Left.Type == "field" { + l = w.Value.Left.Field.Name + if w.Value.Left.Field.FromTable != "" { + l = fmt.Sprint(w.Value.Left.Field.FromTable, ".", l) + } + } else { + l = w.Value.Left.Value + } + var r interface{} + if w.Value.Right.Type == "field" { + r = w.Value.Right.Field.Name + if w.Value.Right.Field.FromTable != "" { + r = builder.Field(fmt.Sprint(w.Value.Right.Field.FromTable, ".", r)) + } + } else { + r = w.Value.Right.Value + } + switch w.Value.Condition { + case "eq", "equal", "=": + + cond = builder.Eq{l: r} + case "neq", "not equal", "!=": + cond = builder.Neq{l: r} + case "gt", "greater than", ">": + cond = builder.Gt{l: r} + case "gte", "greater than or equal", ">=": + cond = builder.Gte{l: r} + case "lt", "less than", "<": + cond = builder.Lt{l: r} + case "lte", "less than or equal", "<=": + cond = builder.Lte{l: r} + case "like": + cond = builder.Like{l, fmt.Sprint(r)} + case "not like": + cond = builder.Like{fmt.Sprint(l, " not"), fmt.Sprint(r)} + case "in": + cond = builder.In(l, r) + case "not in": + cond = builder.NotIn(l, r) + default: + return nil, fmt.Errorf("unknown condition: %s", w.Value.Condition) + } + } + + return cond, nil +} + +func (w *Where) FromSql(sql string) error { + stmt, err := sqlparser.Parse(sql) + if err != nil { + return err + } + + switch stmt.(type) { + case *sqlparser.Select: + sel := stmt.(*sqlparser.Select) + if sel.Where != nil { + return w.fromExpr(sel.Where.Expr) + } + } + return nil +} +func (w *Where) fromExpr(expr sqlparser.Expr) error { + + // log.Println("expr:", reflect.TypeOf(expr), jsonEncode(expr)) + switch expr.(type) { + case *sqlparser.AndExpr: + var left, right Where + w.Cond = CondTypeAnd + and := expr.(*sqlparser.AndExpr) + if err := left.fromExpr(and.Left); err != nil { + return err + } + w.Sub = append(w.Sub, left) + + if err := right.fromExpr(and.Right); err != nil { + return err + } + w.Sub = append(w.Sub, right) + case *sqlparser.OrExpr: + var left, right Where + w.Cond = CondTypeOr + or := expr.(*sqlparser.OrExpr) + + if err := left.fromExpr(or.Left); err != nil { + return err + } + w.Sub = append(w.Sub, left) + if err := right.fromExpr(or.Right); err != nil { + return err + } + w.Sub = append(w.Sub, right) + case *sqlparser.ParenExpr: + paren := expr.(*sqlparser.ParenExpr) + if err := w.fromExpr(paren.Expr); err != nil { + return err + } + case *sqlparser.ComparisonExpr: + w.Cond = CondTypeCondition + return w.conditionExpr(expr) + default: + return fmt.Errorf("unknown expr: %s", reflect.TypeOf(expr)) + log.Println("unknown expr:", reflect.TypeOf(expr), jsonEncode(expr)) + } + return nil +} + +func (w *Where) conditionExpr(expr sqlparser.Expr) error { + // log.Println("conditionExpr:", reflect.TypeOf(expr), jsonEncode(expr)) + switch expr.(type) { + case *sqlparser.ComparisonExpr: + comp := expr.(*sqlparser.ComparisonExpr) + w.Value = &WhereCondition{} + + w.Value.Condition = "unknown" + switch comp.Operator { + case "=": + w.Value.Condition = "eq" + case "!=": + w.Value.Condition = "neq" + case ">": + w.Value.Condition = "gt" + case ">=": + w.Value.Condition = "gte" + case "<": + w.Value.Condition = "lt" + case "<=": + w.Value.Condition = "lte" + case "like": + w.Value.Condition = "like" + case "not like": + w.Value.Condition = "not like" + case "in": + w.Value.Condition = "in" + case "not in": + w.Value.Condition = "not in" + default: + return fmt.Errorf("unknown operator: %s", comp.Operator) + } + + if leftCol, ok := comp.Left.(*sqlparser.ColName); ok { + w.Value.Left = FieldValue{ + Type: "field", + Field: &Field{ + Name: string(leftCol.Name.String()), + FromTable: string(leftCol.Qualifier.Name.String()), + }, + } + + } else if leftVal, ok := comp.Left.(*sqlparser.SQLVal); ok { + w.Value.Left = FieldValue{ + Type: "value", + Value: string(leftVal.Val), + } + } else { + return fmt.Errorf("unexpected left type: %T", comp.Left) + } + + if rightCol, ok := comp.Right.(*sqlparser.ColName); ok { + w.Value.Right = FieldValue{ + Type: "field", + Field: &Field{ + Name: string(rightCol.Name.String()), + FromTable: string(rightCol.Qualifier.Name.String()), + }, + } + } else if rightVal, ok := comp.Right.(*sqlparser.SQLVal); ok { + w.Value.Right = FieldValue{ + Type: "value", + Value: string(rightVal.Val), + } + } else { + return fmt.Errorf("unexpected right type: %T", comp.Right) + } + default: + return fmt.Errorf("unexpected type: %T", expr) + } + // log.Println("Where:", w) + return nil +} + +func (w *Where) FromCond(cond builder.Cond) error { + if cond == nil { + return fmt.Errorf("condition is nil") + } + + switch builder.CondType(cond) { + case "and": + w.Cond = CondTypeAnd + condAnd := builder.GetCondAnd(cond) + for _, v := range condAnd { + var sub Where + sub.FromCond(v) + w.Sub = append(w.Sub, sub) + } + case "or": + w.Cond = CondTypeOr + condOr := builder.GetCondOr(cond) + for _, v := range condOr { + var sub Where + sub.FromCond(v) + w.Sub = append(w.Sub, sub) + } + default: + w.Cond = "condition" + err := w.condition(cond) + if err != nil { + return err + } + } + return nil +} + +func (w *Where) condition(cond builder.Cond) error { + condMap := map[string]interface{}{} + json.Unmarshal([]byte(jsonEncode(cond)), &condMap) + conditionName := "" + switch cond.(type) { + case builder.Eq: + conditionName = "eq" + case builder.Neq: + conditionName = "neq" + case builder.Gt: + conditionName = "gt" + case builder.Gte: + conditionName = "gte" + case builder.Lt: + conditionName = "lt" + case builder.Lte: + conditionName = "lte" + case builder.Like: + conditionName = "like" + case builder.Between: + conditionName = "between" + default: + ct := builder.CondType(cond) + if ct == "in" { + conditionName = "in" + } else if ct == "not in" { + conditionName = "not in" + } else { + return fmt.Errorf("condition type %s not support", ct) + } + } + if len(condMap) > 1 { + w.Cond = CondTypeCondition + for k, v := range condMap { + var sub Where + sub.Value = &WhereCondition{ + Condition: conditionName, + Left: FieldValue{ + Type: "value", + Value: k, + }, + Right: FieldValue{ + Type: "value", + Value: fmt.Sprint(v), + }, + } + w.Sub = append(w.Sub, sub) + } + } else { + w.Value = &WhereCondition{ + Condition: conditionName, + } + for k, v := range condMap { + w.Value.Left = FieldValue{ + Type: "value", + Value: k, + } + w.Value.Right = FieldValue{ + Type: "value", + Value: fmt.Sprint(v), + } + } + } + return nil +} + +type Field struct { + FromTable string `json:"from_table"` + Name string `json:"name"` + As string `json:"as"` +} + +type From struct { + Table string `json:"table"` + SubQuery *Query `json:"sub_query,omitempty"` + raw string `json:"-"` + As string `json:"as"` +} + +type Join struct { + Type string `json:"type"` + Query *Query `json:"query"` + JoinWhere *Where `json:"join_where,omitempty"` +} + +type Limit struct { + Offset int `json:"offset"` + Count int `json:"count"` +} +type OrderByField struct { + Field Field `json:"field"` + Direction string `json:"desc"` +} +type Query struct { + Dialect DialectType `json:"dialect"` + From From `json:"from"` + Join []Join `json:"join,omitempty"` + Fields []Field `json:"field,omitempty"` + Where *Where `json:"where,omitempty"` + GroupBy []Field `json:"group_by,omitempty"` + Having string `json:"having,omitempty"` + OrderBy []OrderByField `json:"order_by,omitempty"` + Limit *Limit `json:"limit,omitempty"` +} + +func getTheLeftmost(node *sqlparser.JoinTableExpr) (left *sqlparser.AliasedTableExpr) { + if node.LeftExpr == nil { + return nil + } + switch left := node.LeftExpr.(type) { + case *sqlparser.AliasedTableExpr: + return left + case *sqlparser.JoinTableExpr: + return getTheLeftmost(left) + default: + return nil + } +} + +func nodeExtractRawSQL(node sqlparser.SQLNode) string { + buff := sqlparser.NewTrackedBuffer(nil) + node.Format(buff) + return buff.String() +} + +func (q *Query) FromSelect(selectStmt *sqlparser.Select) error { + if selectStmt == nil { + return fmt.Errorf("selectStmt is nil") + } + + if selectStmt.SelectExprs != nil { + for _, v := range selectStmt.SelectExprs { + switch expr := v.(type) { + case *sqlparser.StarExpr: + q.Fields = append(q.Fields, Field{ + Name: "*", + }) + case *sqlparser.AliasedExpr: + // log.Println("selectStmt.AliasedExpr", reflect.TypeOf(expr.Expr), jsonEncode(expr.Expr)) + switch colName := expr.Expr.(type) { + case *sqlparser.ColName: + q.Fields = append(q.Fields, Field{ + FromTable: colName.Qualifier.Name.String(), + Name: string(colName.Name.String()), + As: expr.As.String(), + }) + case *sqlparser.FuncExpr: + name := nodeExtractRawSQL(colName) + q.Fields = append(q.Fields, Field{ + Name: name, + As: expr.As.String(), + }) + } + } + } + } + + if selectStmt.Where != nil { + if q.Where == nil { + q.Where = &Where{} + } + if err := q.Where.fromExpr(selectStmt.Where.Expr); err != nil { + return err + } + } + + if selectStmt.From != nil { + if len(selectStmt.From) > 1 { + return fmt.Errorf("not support multi table") + } + form := selectStmt.From[0] + + switch form := form.(type) { + case *sqlparser.AliasedTableExpr: //未联表 + q.From.Table = string(form.Expr.(sqlparser.TableName).Name.String()) + if form.As.String() != "" { + q.From.As = form.As.String() + } + case *sqlparser.JoinTableExpr: //有联表 + + //确定主表 + leftmost := getTheLeftmost(form) + if leftmost != nil { + // log.Println("leftmost:", jsonEncode(leftmost)) + if leftmost.As.String() != "" { + q.From.As = leftmost.As.String() + } + switch leftmostExpr := leftmost.Expr.(type) { + case sqlparser.TableName: + q.From.Table = string(leftmostExpr.Name.String()) + case *sqlparser.Subquery: + subQuery := &Query{} + if err := subQuery.FromSelect(leftmostExpr.Select.(*sqlparser.Select)); err != nil { + return err + } + q.From.SubQuery = subQuery + } + } else { + //找不到最左边的表 + return fmt.Errorf("can not find the leftmost table") + } + + //确定联表 + for { + if form == nil { + break + } + join := Join{ + Type: strings.ToUpper(strings.Trim(strings.TrimSuffix(strings.Trim(strings.ToLower(form.Join), " "), "join"), " ")), + } + + switch right := form.RightExpr.(type) { + case *sqlparser.AliasedTableExpr: + if right.As.String() != "" { + join.Query = &Query{From: From{As: right.As.String()}} + } + switch rightExpr := right.Expr.(type) { + case sqlparser.TableName: + join.Query.From.Table = string(rightExpr.Name.String()) + case *sqlparser.Subquery: + subQuery := &Query{} + if err := subQuery.FromSelect(rightExpr.Select.(*sqlparser.Select)); err != nil { + return err + } + join.Query.From.SubQuery = subQuery + } + } + + join.JoinWhere = &Where{} + if err := join.JoinWhere.fromExpr(form.Condition.On); err != nil { + return err + } + + q.Join = append(q.Join, join) + + var ok bool + form, ok = form.LeftExpr.(*sqlparser.JoinTableExpr) + if !ok { + break + } + } + } + + // log.Println("selectStmt.From:", reflect.TypeOf(selectStmt.From), jsonEncode(selectStmt.From)) + } + + if selectStmt.GroupBy != nil { + // log.Println("selectStmt.GroupBy:", reflect.TypeOf(selectStmt.GroupBy), jsonEncode(selectStmt.GroupBy)) + for _, v := range selectStmt.GroupBy { + switch colName := v.(type) { + case *sqlparser.ColName: + q.GroupBy = append(q.GroupBy, Field{ + FromTable: colName.Qualifier.Name.String(), + Name: string(colName.Name.String()), + }) + } + } + } + + if selectStmt.OrderBy != nil { + // log.Println("selectStmt.OrderBy:", reflect.TypeOf(selectStmt.OrderBy), jsonEncode(selectStmt.OrderBy)) + for _, v := range selectStmt.OrderBy { + switch colName := v.Expr.(type) { + case *sqlparser.ColName: + q.OrderBy = append(q.OrderBy, OrderByField{ + Field: Field{ + FromTable: colName.Qualifier.Name.String(), + Name: string(colName.Name.String()), + }, + Direction: v.Direction, + }) + } + } + } + + if selectStmt.Having != nil { + // log.Println("selectStmt.Having:", reflect.TypeOf(selectStmt.Having), jsonEncode(selectStmt.Having)) + q.Having = nodeExtractRawSQL(selectStmt.Having.Expr) + } + + if selectStmt.Limit != nil { + // log.Println("selectStmt.Limit:", jsonEncode(selectStmt.Limit)) + if q.Limit == nil { + q.Limit = &Limit{} + } + + switch offset := selectStmt.Limit.Offset.(type) { + case *sqlparser.SQLVal: + q.Limit.Offset = cast.ToInt(fmt.Sprintf("%s", offset.Val)) + } + switch count := selectStmt.Limit.Rowcount.(type) { + case *sqlparser.SQLVal: + q.Limit.Count = cast.ToInt(fmt.Sprintf("%s", count.Val)) + } + } + // log.Println("query:", jsonEncode(q)) + return nil +} + +func (q *Query) FromSql(sql string) error { + stmt, err := sqlparser.Parse(sql) + if err != nil { + return err + } + switch stmt.(type) { + case *sqlparser.Select: + selectStmt := stmt.(*sqlparser.Select) + return q.FromSelect(selectStmt) + default: + return fmt.Errorf("unexpected type: %T", stmt) + } +} + +type DialectType string + +const ( + POSTGRES DialectType = "postgres" + SQLITE = "sqlite3" + MYSQL = "mysql" + MSSQL = "mssql" + ORACLE = "oracle" + UNION = "union" + INTERSECT = "intersect" + EXCEPT = "except" +) + +func (q *Query) ToBuilder() (*builder.Builder, error) { + var dialect DialectType = MYSQL + if q.Dialect != "" { + dialect = q.Dialect + } + + fields := []string{} + for i := range q.Fields { + name := q.Fields[i].Name + + if q.Fields[i].FromTable != "" { + name = fmt.Sprintf("%s.`%s`", q.Fields[i].FromTable, name) + } + if q.Fields[i].As != "" { + name = fmt.Sprintf("%s AS %s", name, q.Fields[i].As) + } + fields = append(fields, name) + } + b := builder.Dialect(string(dialect)).Select(fields...) + + if q.From.raw != "" { + if q.From.As != "" { + b = b.From(fmt.Sprintf("(%s)", q.From.raw), q.From.As) + } else { + b = b.From(fmt.Sprintf("(%s)", q.From.raw)) + } + } else if q.From.Table != "" { + if q.From.As != "" { + b = b.From(q.From.Table, q.From.As) + } else { + b = b.From(q.From.Table) + } + } else { + if q.From.SubQuery != nil { + subBuilder, err := q.From.SubQuery.ToBuilder() + if err != nil { + return nil, err + } + if q.From.As != "" { + b = b.From(subBuilder, q.From.As) + } else { + b = b.From(subBuilder) + } + } + } + + if q.Join != nil { + for _, join := range q.Join { + if join.Query != nil { + // log.Println("join:", jsonEncode(join)) + if join.Query.From.As == "" { + return nil, fmt.Errorf("join table must have an alias") + } + + if join.Query.From.raw != "" { + + cond, err := join.JoinWhere.ToCond() + if err != nil { + return nil, err + } + + b = b.Join(join.Type, fmt.Sprintf("(%s)", join.Query.From.raw), cond, join.Query.From.As) + + } else if join.Query.From.SubQuery != nil { + subBuilder, err := join.Query.From.SubQuery.ToBuilder() + if err != nil { + return nil, err + } + cond, err := join.JoinWhere.ToCond() + if err != nil { + return nil, err + } + b = b.Join(join.Type, subBuilder, cond, join.Query.From.As) + } else { + cond, err := join.JoinWhere.ToCond() + if err != nil { + return nil, err + } + b = b.Join(join.Type, join.Query.From.Table, cond, join.Query.From.As) + } + } + } + } + + if q.GroupBy != nil { + groupBy := []string{} + for i := range q.GroupBy { + name := q.GroupBy[i].Name + if q.GroupBy[i].FromTable != "" { + name = fmt.Sprintf("%s.`%s`", q.GroupBy[i].FromTable, name) + } + groupBy = append(groupBy, name) + } + + b = b.GroupBy(fmt.Sprintf("( %s )", strings.Join(groupBy, ", "))) + } + + if q.OrderBy != nil { + orderBy := []string{} + for i := range q.OrderBy { + name := q.OrderBy[i].Field.Name + if q.OrderBy[i].Field.FromTable != "" { + name = fmt.Sprintf("%s.`%s`", q.OrderBy[i].Field.FromTable, name) + } + if q.OrderBy[i].Direction != "" { + name = fmt.Sprintf("%s %s", name, q.OrderBy[i].Direction) + } + orderBy = append(orderBy, name) + } + + b = b.OrderBy(fmt.Sprintf("%s", strings.Join(orderBy, ", "))) + } + + if q.Where != nil { + cond, err := q.Where.ToCond() + if err != nil { + return nil, err + } + b = b.Where(cond) + } + + if q.Having != "" { + + b = b.Having(q.Having) + } + + if q.Limit != nil { + b = b.Limit(q.Limit.Count, q.Limit.Offset) + } + return b, nil +} + +func (q *Query) ToSql() (string, error) { + b, err := q.ToBuilder() + if err != nil { + return "", fmt.Errorf("failed to build query: %w", err) + } + + buf := builder.NewWriter() + b.WriteTo(buf) + log.Println("NewWriter:", buf.String()) + return b.ToBoundSQL() +} diff --git a/vql.go b/vql.go new file mode 100644 index 0000000..357093b --- /dev/null +++ b/vql.go @@ -0,0 +1,166 @@ +package vql + +import ( + "encoding/json" + "fmt" + "strings" + "sync" +) + +type VirtualQlConvert func(val map[string]interface{}) (string, error) + +func NoParameConvertFunc(vTable string, convertSql string) *VirtualTable { + return &VirtualTable{ + MappingTableName: vTable, + SqlConvert: func(val map[string]interface{}) (string, error) { + return convertSql, nil + }, + } +} + +type VirtualTable struct { + MappingTableName string //example: user(db:db1,id:1) + //限定条件 + SqlConvert VirtualQlConvert +} + +func parseTableName(tableNameExpression string) (tableName string, args map[string]string, err error) { + args = make(map[string]string) + source := []rune(tableNameExpression) + + flag := "table" + argStr := "" + for i := range source { + switch flag { + case "table": + if source[i] == '(' { + flag = "arg" + continue + } + tableName += string(source[i]) + case "arg": + if source[i] == ')' { + flag = "end" + continue + } + argStr += string(source[i]) + case "end": + break + } + } + + argsArr := []string{} + if argStr != "" { + argsArr = strings.Split(strings.TrimSpace(argStr), ",") + } + for i := range argsArr { + arg := strings.Split(argsArr[i], ":") + if len(arg) != 2 { + err = fmt.Errorf("table name expression error") + return + } + args[arg[0]] = arg[1] + } + + return +} + +var DefaultVirtualQL = &VirtualQL{} + +type VirtualQL struct { + tables map[string]VirtualTable + locker sync.RWMutex +} + +func (vql *VirtualQL) getTable(tableName string) (vt VirtualTable, args map[string]string, err error) { + vql.locker.RLock() + defer vql.locker.RUnlock() + + tableName, args, err = parseTableName(tableName) + vqlTable := vql.tables[tableName] + return vqlTable, args, err +} + +func (vql *VirtualQL) Register(v *VirtualTable) { + if vql.tables == nil { + vql.tables = make(map[string]VirtualTable) + } + vql.locker.Lock() + defer vql.locker.Unlock() + vql.tables[v.MappingTableName] = *v +} + +func (vql *VirtualQL) Compile(sql string, param map[string]interface{}) (*Query, error) { + query, err := Analyze(sql) + if err != nil { + return nil, err + } + err = vql.convert(query, param) + if err != nil { + return nil, err + } + return query, nil +} + +func (vql *VirtualQL) convert(query *Query, val map[string]interface{}) error { + if query.From.Table == "" && query.From.SubQuery == nil { + return fmt.Errorf("from table is empty") + } + if query.From.Table != "" { + + vtable, args, err := vql.getTable(query.From.Table) + if err != nil { + return err + } + if vtable.SqlConvert == nil { + return fmt.Errorf("table %s is not register", query.From.Table) + } + + if val == nil { + val = make(map[string]interface{}) + } + + for k, v := range args { + val[k] = v + } + + convertSql, err := vtable.SqlConvert(val) + if err != nil { + return err + } + + newAs := query.From.As + if newAs == "" { + newAs = query.From.Table + } + query.From = From{ + raw: convertSql, + As: newAs, + } + } else { + err := vql.convert(query.From.SubQuery, val) + if err != nil { + return err + } + } + + for i := range query.Join { + err := vql.convert(query.Join[i].Query, val) + if err != nil { + return err + } + } + + return nil +} + +func jsonEncode(val interface{}) string { + raw, _ := json.Marshal(val) + return string(raw) +} + +func Analyze(sql string) (*Query, error) { + q := Query{} + err := q.FromSql(sql) + return &q, err +}