|
|
package vql
import ( "encoding/json" "fmt" "log" "reflect" "strings"
"git.ouxuan.net/3136352472/vql/builder" "github.com/spf13/cast" "github.com/xwb1989/sqlparser" )
// CondType identifies the kind of node in a Where tree.
type CondType string

// Node kinds of a Where tree.
const (
	// CondTypeCondition marks a leaf whose comparison lives in Where.Value.
	CondTypeCondition CondType = "condition"
	// CondTypeAnd joins the nodes in Where.Sub with AND.
	CondTypeAnd CondType = "and"
	// CondTypeOr joins the nodes in Where.Sub with OR.
	CondTypeOr CondType = "or"
)
type FieldValue struct { Type string `json:"type"` // field, value
Field *Field `json:"field,omitempty"` Value string `json:"value,omitempty"` }
type WhereCondition struct { Condition string `json:"condition"` Left FieldValue `json:"left"` Right FieldValue `json:"right"` }
type Where struct { Cond CondType `json:"cond"` Sub []Where `json:"sub,omitempty"` Value *WhereCondition `json:"value,omitempty"` }
func (w *Where) ToCond() (builder.Cond, error) { var cond builder.Cond if w.Cond != CondTypeCondition { var conds []builder.Cond for _, v := range w.Sub { subCond, err := v.ToCond() if err != nil { return nil, err } conds = append(conds, subCond) }
switch w.Cond { case CondTypeAnd: cond = builder.And(conds...) case CondTypeOr: cond = builder.Or(conds...) } } else {
l := "" if w.Value.Left.Type == "field" { l = w.Value.Left.Field.Name if w.Value.Left.Field.FromTable != "" { l = fmt.Sprint(w.Value.Left.Field.FromTable, ".", l) } } else { l = w.Value.Left.Value } var r interface{} if w.Value.Right.Type == "field" { r = w.Value.Right.Field.Name if w.Value.Right.Field.FromTable != "" { r = builder.Field(fmt.Sprint(w.Value.Right.Field.FromTable, ".", r)) } } else { r = w.Value.Right.Value } switch w.Value.Condition { case "eq", "equal", "=":
cond = builder.Eq{l: r} case "neq", "not equal", "!=": cond = builder.Neq{l: r} case "gt", "greater than", ">": cond = builder.Gt{l: r} case "gte", "greater than or equal", ">=": cond = builder.Gte{l: r} case "lt", "less than", "<": cond = builder.Lt{l: r} case "lte", "less than or equal", "<=": cond = builder.Lte{l: r} case "like": cond = builder.Like{l, fmt.Sprint(r)} case "not like": cond = builder.Like{fmt.Sprint(l, " not"), fmt.Sprint(r)} case "in": cond = builder.In(l, r) case "not in": cond = builder.NotIn(l, r) default: return nil, fmt.Errorf("unknown condition: %s", w.Value.Condition) } }
return cond, nil }
func (w *Where) FromSql(sql string) error { stmt, err := sqlparser.Parse(sql) if err != nil { return err }
switch stmt.(type) { case *sqlparser.Select: sel := stmt.(*sqlparser.Select) if sel.Where != nil { return w.fromExpr(sel.Where.Expr) } } return nil } func (w *Where) fromExpr(expr sqlparser.Expr) error {
// log.Println("expr:", reflect.TypeOf(expr), jsonEncode(expr))
switch expr.(type) { case *sqlparser.AndExpr: var left, right Where w.Cond = CondTypeAnd and := expr.(*sqlparser.AndExpr) if err := left.fromExpr(and.Left); err != nil { return err } w.Sub = append(w.Sub, left)
if err := right.fromExpr(and.Right); err != nil { return err } w.Sub = append(w.Sub, right) case *sqlparser.OrExpr: var left, right Where w.Cond = CondTypeOr or := expr.(*sqlparser.OrExpr)
if err := left.fromExpr(or.Left); err != nil { return err } w.Sub = append(w.Sub, left) if err := right.fromExpr(or.Right); err != nil { return err } w.Sub = append(w.Sub, right) case *sqlparser.ParenExpr: paren := expr.(*sqlparser.ParenExpr) if err := w.fromExpr(paren.Expr); err != nil { return err } case *sqlparser.ComparisonExpr: w.Cond = CondTypeCondition return w.conditionExpr(expr) default: return fmt.Errorf("unknown expr: %s", reflect.TypeOf(expr)) log.Println("unknown expr:", reflect.TypeOf(expr), jsonEncode(expr)) } return nil }
func (w *Where) conditionExpr(expr sqlparser.Expr) error { // log.Println("conditionExpr:", reflect.TypeOf(expr), jsonEncode(expr))
switch expr.(type) { case *sqlparser.ComparisonExpr: comp := expr.(*sqlparser.ComparisonExpr) w.Value = &WhereCondition{}
w.Value.Condition = "unknown" switch comp.Operator { case "=": w.Value.Condition = "eq" case "!=": w.Value.Condition = "neq" case ">": w.Value.Condition = "gt" case ">=": w.Value.Condition = "gte" case "<": w.Value.Condition = "lt" case "<=": w.Value.Condition = "lte" case "like": w.Value.Condition = "like" case "not like": w.Value.Condition = "not like" case "in": w.Value.Condition = "in" case "not in": w.Value.Condition = "not in" default: return fmt.Errorf("unknown operator: %s", comp.Operator) }
if leftCol, ok := comp.Left.(*sqlparser.ColName); ok { w.Value.Left = FieldValue{ Type: "field", Field: &Field{ Name: string(leftCol.Name.String()), FromTable: string(leftCol.Qualifier.Name.String()), }, }
} else if leftVal, ok := comp.Left.(*sqlparser.SQLVal); ok { w.Value.Left = FieldValue{ Type: "value", Value: string(leftVal.Val), } } else { return fmt.Errorf("unexpected left type: %T", comp.Left) }
if rightCol, ok := comp.Right.(*sqlparser.ColName); ok { w.Value.Right = FieldValue{ Type: "field", Field: &Field{ Name: string(rightCol.Name.String()), FromTable: string(rightCol.Qualifier.Name.String()), }, } } else if rightVal, ok := comp.Right.(*sqlparser.SQLVal); ok { w.Value.Right = FieldValue{ Type: "value", Value: string(rightVal.Val), } } else { return fmt.Errorf("unexpected right type: %T", comp.Right) } default: return fmt.Errorf("unexpected type: %T", expr) } // log.Println("Where:", w)
return nil }
func (w *Where) FromCond(cond builder.Cond) error { if cond == nil { return fmt.Errorf("condition is nil") }
switch builder.CondType(cond) { case "and": w.Cond = CondTypeAnd condAnd := builder.GetCondAnd(cond) for _, v := range condAnd { var sub Where sub.FromCond(v) w.Sub = append(w.Sub, sub) } case "or": w.Cond = CondTypeOr condOr := builder.GetCondOr(cond) for _, v := range condOr { var sub Where sub.FromCond(v) w.Sub = append(w.Sub, sub) } default: w.Cond = "condition" err := w.condition(cond) if err != nil { return err } } return nil }
func (w *Where) condition(cond builder.Cond) error { condMap := map[string]interface{}{} json.Unmarshal([]byte(jsonEncode(cond)), &condMap) conditionName := "" switch cond.(type) { case builder.Eq: conditionName = "eq" case builder.Neq: conditionName = "neq" case builder.Gt: conditionName = "gt" case builder.Gte: conditionName = "gte" case builder.Lt: conditionName = "lt" case builder.Lte: conditionName = "lte" case builder.Like: conditionName = "like" case builder.Between: conditionName = "between" default: ct := builder.CondType(cond) if ct == "in" { conditionName = "in" } else if ct == "not in" { conditionName = "not in" } else { return fmt.Errorf("condition type %s not support", ct) } } if len(condMap) > 1 { w.Cond = CondTypeCondition for k, v := range condMap { var sub Where sub.Value = &WhereCondition{ Condition: conditionName, Left: FieldValue{ Type: "value", Value: k, }, Right: FieldValue{ Type: "value", Value: fmt.Sprint(v), }, } w.Sub = append(w.Sub, sub) } } else { w.Value = &WhereCondition{ Condition: conditionName, } for k, v := range condMap { w.Value.Left = FieldValue{ Type: "value", Value: k, } w.Value.Right = FieldValue{ Type: "value", Value: fmt.Sprint(v), } } } return nil }
type Field struct { FromTable string `json:"from_table"` Name string `json:"name"` As string `json:"as"` }
type From struct { Table string `json:"table"` SubQuery *Query `json:"sub_query,omitempty"` raw string `json:"-"` As string `json:"as"` }
type Join struct { Type string `json:"type"` Query *Query `json:"query"` JoinWhere *Where `json:"join_where,omitempty"` }
type Limit struct { Offset int `json:"offset"` Count int `json:"count"` } type OrderByField struct { Field Field `json:"field"` Direction string `json:"desc"` } type Query struct { Dialect DialectType `json:"dialect"` From From `json:"from"` Join []Join `json:"join,omitempty"` Fields []Field `json:"field,omitempty"` Where *Where `json:"where,omitempty"` GroupBy []Field `json:"group_by,omitempty"` Having string `json:"having,omitempty"` OrderBy []OrderByField `json:"order_by,omitempty"` Limit *Limit `json:"limit,omitempty"` }
func getTheLeftmost(node *sqlparser.JoinTableExpr) (left *sqlparser.AliasedTableExpr) { if node.LeftExpr == nil { return nil } switch left := node.LeftExpr.(type) { case *sqlparser.AliasedTableExpr: return left case *sqlparser.JoinTableExpr: return getTheLeftmost(left) default: return nil } }
func nodeExtractRawSQL(node sqlparser.SQLNode) string { buff := sqlparser.NewTrackedBuffer(nil) node.Format(buff) return buff.String() }
func (q *Query) FromSelect(selectStmt *sqlparser.Select) error { if selectStmt == nil { return fmt.Errorf("selectStmt is nil") }
if selectStmt.SelectExprs != nil { for _, v := range selectStmt.SelectExprs { switch expr := v.(type) { case *sqlparser.StarExpr: // log.Println("selectStmt.StarExpr", reflect.TypeOf(expr), jsonEncode(expr))
q.Fields = append(q.Fields, Field{ Name: "*", FromTable: expr.TableName.Name.String(), }) case *sqlparser.AliasedExpr: // log.Println("selectStmt.AliasedExpr", reflect.TypeOf(expr.Expr), jsonEncode(expr.Expr))
switch colName := expr.Expr.(type) { case *sqlparser.ColName: q.Fields = append(q.Fields, Field{ FromTable: colName.Qualifier.Name.String(), Name: string(colName.Name.String()), As: expr.As.String(), }) case *sqlparser.FuncExpr: name := nodeExtractRawSQL(colName) q.Fields = append(q.Fields, Field{ Name: name, As: expr.As.String(), }) } } } }
if selectStmt.Where != nil { if q.Where == nil { q.Where = &Where{} } if err := q.Where.fromExpr(selectStmt.Where.Expr); err != nil { return err } }
if selectStmt.From != nil { if len(selectStmt.From) > 1 { return fmt.Errorf("not support multi table") } form := selectStmt.From[0]
switch form := form.(type) { case *sqlparser.AliasedTableExpr: //未联表
if form.As.String() != "" { q.From.As = form.As.String() } switch exp := form.Expr.(type) { case sqlparser.TableName: // log.Println("TableName:", jsonEncode(exp))
q.From.Table = string(exp.Name.String())
case *sqlparser.Subquery: subQuery := &Query{} if err := subQuery.FromSelect(exp.Select.(*sqlparser.Select)); err != nil { return err } q.From.SubQuery = subQuery } case *sqlparser.JoinTableExpr: //有联表
// log.Println("JoinTableExpr:", jsonEncode(form))
//确定主表
leftmost := getTheLeftmost(form) if leftmost != nil { // log.Println("leftmost:", jsonEncode(leftmost))
if leftmost.As.String() != "" { q.From.As = leftmost.As.String() } switch leftmostExpr := leftmost.Expr.(type) { case sqlparser.TableName: q.From.Table = string(leftmostExpr.Name.String()) case *sqlparser.Subquery: subQuery := &Query{} if err := subQuery.FromSelect(leftmostExpr.Select.(*sqlparser.Select)); err != nil { return err } q.From.SubQuery = subQuery }
} else { //找不到最左边的表
return fmt.Errorf("can not find the leftmost table") }
//确定联表
for { if form == nil { break } join := Join{ Type: strings.ToUpper(strings.Trim(strings.TrimSuffix(strings.Trim(strings.ToLower(form.Join), " "), "join"), " ")), }
switch right := form.RightExpr.(type) { case *sqlparser.AliasedTableExpr: join.Query = &Query{} if right.As.String() != "" { join.Query.From.As = right.As.String() } switch rightExpr := right.Expr.(type) { case sqlparser.TableName: // log.Println("rightExpr:", jsonEncode(rightExpr))
join.Query.From.Table = string(rightExpr.Name.String()) case *sqlparser.Subquery: subQuery := &Query{} if err := subQuery.FromSelect(rightExpr.Select.(*sqlparser.Select)); err != nil { return err } join.Query.From.SubQuery = subQuery } }
join.JoinWhere = &Where{} if err := join.JoinWhere.fromExpr(form.Condition.On); err != nil { return err }
q.Join = append(q.Join, join)
var ok bool form, ok = form.LeftExpr.(*sqlparser.JoinTableExpr) if !ok { break } } }
// log.Println("selectStmt.From:", reflect.TypeOf(selectStmt.From), jsonEncode(selectStmt.From))
}
if selectStmt.GroupBy != nil { // log.Println("selectStmt.GroupBy:", reflect.TypeOf(selectStmt.GroupBy), jsonEncode(selectStmt.GroupBy))
for _, v := range selectStmt.GroupBy { switch colName := v.(type) { case *sqlparser.ColName: q.GroupBy = append(q.GroupBy, Field{ FromTable: colName.Qualifier.Name.String(), Name: string(colName.Name.String()), }) } } }
if selectStmt.OrderBy != nil { // log.Println("selectStmt.OrderBy:", reflect.TypeOf(selectStmt.OrderBy), jsonEncode(selectStmt.OrderBy))
for _, v := range selectStmt.OrderBy { switch colName := v.Expr.(type) { case *sqlparser.ColName: q.OrderBy = append(q.OrderBy, OrderByField{ Field: Field{ FromTable: colName.Qualifier.Name.String(), Name: string(colName.Name.String()), }, Direction: v.Direction, }) } } }
if selectStmt.Having != nil { // log.Println("selectStmt.Having:", reflect.TypeOf(selectStmt.Having), jsonEncode(selectStmt.Having))
q.Having = nodeExtractRawSQL(selectStmt.Having.Expr) }
if selectStmt.Limit != nil { // log.Println("selectStmt.Limit:", jsonEncode(selectStmt.Limit))
if q.Limit == nil { q.Limit = &Limit{} }
switch offset := selectStmt.Limit.Offset.(type) { case *sqlparser.SQLVal: q.Limit.Offset = cast.ToInt(fmt.Sprintf("%s", offset.Val)) } switch count := selectStmt.Limit.Rowcount.(type) { case *sqlparser.SQLVal: q.Limit.Count = cast.ToInt(fmt.Sprintf("%s", count.Val)) } } // log.Println("query:", jsonEncode(q))
return nil }
func (q *Query) FromSql(sql string) error { stmt, err := sqlparser.Parse(sql) if err != nil { return err } switch stmt.(type) { case *sqlparser.Select: selectStmt := stmt.(*sqlparser.Select) return q.FromSelect(selectStmt) default: return fmt.Errorf("unexpected type: %T", stmt) } }
// DialectType names the SQL dialect handed to builder.Dialect by ToBuilder.
type DialectType string

// Dialect names.
//
// NOTE(review): only POSTGRES is typed as DialectType; the others are untyped
// string constants (so they also assign to plain string variables). Keep it
// that way unless every caller is checked.
const (
	POSTGRES DialectType = "postgres"
	SQLITE               = "sqlite3"
	MYSQL                = "mysql"
	MSSQL                = "mssql"
	ORACLE               = "oracle"
)

// Set-operation keywords (untyped string constants). They are not SQL
// dialects, despite historically sharing the dialect constant group.
const (
	UNION     = "union"
	INTERSECT = "intersect"
	EXCEPT    = "except"
)
func (q *Query) ToBuilder() (*builder.Builder, error) { var dialect DialectType = MYSQL if q.Dialect != "" { dialect = q.Dialect }
fields := []string{} for i := range q.Fields { name := q.Fields[i].Name
if q.Fields[i].FromTable != "" { name = fmt.Sprintf("%s.`%s`", q.Fields[i].FromTable, name) } if q.Fields[i].As != "" { name = fmt.Sprintf("%s AS %s", name, q.Fields[i].As) } fields = append(fields, name) } b := builder.Dialect(string(dialect)).Select(fields...)
if q.From.raw != "" { if q.From.As != "" { b = b.From(fmt.Sprintf("(%s)", q.From.raw), q.From.As) } else { b = b.From(fmt.Sprintf("(%s)", q.From.raw)) } } else if q.From.Table != "" { if q.From.As != "" { b = b.From(q.From.Table, q.From.As) } else { b = b.From(q.From.Table) } } else { if q.From.SubQuery != nil { subBuilder, err := q.From.SubQuery.ToBuilder() if err != nil { return nil, err } if q.From.As != "" { b = b.From(subBuilder, q.From.As) } else { b = b.From(subBuilder) } } }
if q.Join != nil { for _, join := range q.Join { if join.Query != nil { // log.Println("join:", jsonEncode(join))
if join.Query.From.As == "" { return nil, fmt.Errorf("join table must have an alias") }
if join.Query.From.raw != "" {
cond, err := join.JoinWhere.ToCond() if err != nil { return nil, err }
b = b.Join(join.Type, fmt.Sprintf("(%s)", join.Query.From.raw), cond, join.Query.From.As)
} else if join.Query.From.SubQuery != nil { subBuilder, err := join.Query.From.SubQuery.ToBuilder() if err != nil { return nil, err } cond, err := join.JoinWhere.ToCond() if err != nil { return nil, err } b = b.Join(join.Type, subBuilder, cond, join.Query.From.As) } else { cond, err := join.JoinWhere.ToCond() if err != nil { return nil, err } b = b.Join(join.Type, join.Query.From.Table, cond, join.Query.From.As) } } } }
if q.GroupBy != nil { groupBy := []string{} for i := range q.GroupBy { name := q.GroupBy[i].Name if q.GroupBy[i].FromTable != "" { name = fmt.Sprintf("%s.`%s`", q.GroupBy[i].FromTable, name) } groupBy = append(groupBy, name) }
b = b.GroupBy(fmt.Sprintf("( %s )", strings.Join(groupBy, ", "))) }
if q.OrderBy != nil { orderBy := []string{} for i := range q.OrderBy { name := q.OrderBy[i].Field.Name if q.OrderBy[i].Field.FromTable != "" { name = fmt.Sprintf("%s.`%s`", q.OrderBy[i].Field.FromTable, name) } if q.OrderBy[i].Direction != "" { name = fmt.Sprintf("%s %s", name, q.OrderBy[i].Direction) } orderBy = append(orderBy, name) }
b = b.OrderBy(fmt.Sprintf("%s", strings.Join(orderBy, ", "))) }
if q.Where != nil { cond, err := q.Where.ToCond() if err != nil { return nil, err } b = b.Where(cond) }
if q.Having != "" {
b = b.Having(q.Having) }
if q.Limit != nil { b = b.Limit(q.Limit.Count, q.Limit.Offset) } return b, nil }
func (q *Query) ToSql() (string, error) { b, err := q.ToBuilder() if err != nil { return "", fmt.Errorf("failed to build query: %w", err) }
return b.ToBoundSQL() }
|