package vql import ( "encoding/json" "fmt" "log" "reflect" "strings" "git.ouxuan.net/3136352472/vql/builder" "github.com/spf13/cast" "github.com/xwb1989/sqlparser" ) type CondType string const ( CondTypeCondition CondType = "condition" CondTypeAnd CondType = "and" CondTypeOr CondType = "or" ) type FieldValue struct { Type string `json:"type"` // field, value Field *Field `json:"field,omitempty"` Value string `json:"value,omitempty"` } type WhereCondition struct { Condition string `json:"condition"` Left FieldValue `json:"left"` Right FieldValue `json:"right"` } type Where struct { Cond CondType `json:"cond"` Sub []Where `json:"sub,omitempty"` Value *WhereCondition `json:"value,omitempty"` } func (w *Where) ToCond() (builder.Cond, error) { var cond builder.Cond if w.Cond != CondTypeCondition { var conds []builder.Cond for _, v := range w.Sub { subCond, err := v.ToCond() if err != nil { return nil, err } conds = append(conds, subCond) } switch w.Cond { case CondTypeAnd: cond = builder.And(conds...) case CondTypeOr: cond = builder.Or(conds...) } } else { l := "" if w.Value.Left.Type == "field" { l = w.Value.Left.Field.Name if w.Value.Left.Field.FromTable != "" { l = fmt.Sprint(w.Value.Left.Field.FromTable, ".", l) } } else { l = w.Value.Left.Value } var r interface{} if w.Value.Right.Type == "field" { r = w.Value.Right.Field.Name if w.Value.Right.Field.FromTable != "" { r = builder.Field(fmt.Sprint(w.Value.Right.Field.FromTable, ".", r)) } } else { r = w.Value.Right.Value } switch w.Value.Condition { case "eq", "equal", "=": cond = builder.Eq{l: r} case "neq", "not equal", "!=": cond = builder.Neq{l: r} case "gt", "greater than", ">": cond = builder.Gt{l: r} case "gte", "greater than or equal", ">=": cond = builder.Gte{l: r} case "lt", "less than", "<": cond = builder.Lt{l: r} case "lte", "less than or equal", "<=": cond = builder.Lte{l: r} case "like": cond = builder.Like{l, fmt.Sprint(r)} case "not like": cond = builder.Like{fmt.Sprint(l, " not"), fmt.Sprint(r)} case "in": cond = builder.In(l, r) case "not in": cond = builder.NotIn(l, r) default: return nil, fmt.Errorf("unknown condition: %s", w.Value.Condition) } } return cond, nil } func (w *Where) FromSql(sql string) error { stmt, err := sqlparser.Parse(sql) if err != nil { return err } switch stmt.(type) { case *sqlparser.Select: sel := stmt.(*sqlparser.Select) if sel.Where != nil { return w.fromExpr(sel.Where.Expr) } } return nil } func (w *Where) fromExpr(expr sqlparser.Expr) error { // log.Println("expr:", reflect.TypeOf(expr), jsonEncode(expr)) switch expr.(type) { case *sqlparser.AndExpr: var left, right Where w.Cond = CondTypeAnd and := expr.(*sqlparser.AndExpr) if err := left.fromExpr(and.Left); err != nil { return err } w.Sub = append(w.Sub, left) if err := right.fromExpr(and.Right); err != nil { return err } w.Sub = append(w.Sub, right) case *sqlparser.OrExpr: var left, right Where w.Cond = CondTypeOr or := expr.(*sqlparser.OrExpr) if err := left.fromExpr(or.Left); err != nil { return err } w.Sub = append(w.Sub, left) if err := right.fromExpr(or.Right); err != nil { return err } w.Sub = append(w.Sub, right) case *sqlparser.ParenExpr: paren := expr.(*sqlparser.ParenExpr) if err := w.fromExpr(paren.Expr); err != nil { return err } case *sqlparser.ComparisonExpr: w.Cond = CondTypeCondition return w.conditionExpr(expr) default: return fmt.Errorf("unknown expr: %s", reflect.TypeOf(expr)) log.Println("unknown expr:", reflect.TypeOf(expr), jsonEncode(expr)) } return nil } func (w *Where) conditionExpr(expr sqlparser.Expr) error { // log.Println("conditionExpr:", reflect.TypeOf(expr), jsonEncode(expr)) switch expr.(type) { case *sqlparser.ComparisonExpr: comp := expr.(*sqlparser.ComparisonExpr) w.Value = &WhereCondition{} w.Value.Condition = "unknown" switch comp.Operator { case "=": w.Value.Condition = "eq" case "!=": w.Value.Condition = "neq" case ">": w.Value.Condition = "gt" case ">=": w.Value.Condition = "gte" case "<": w.Value.Condition = "lt" case "<=": w.Value.Condition = "lte" case "like": w.Value.Condition = "like" case "not like": w.Value.Condition = "not like" case "in": w.Value.Condition = "in" case "not in": w.Value.Condition = "not in" default: return fmt.Errorf("unknown operator: %s", comp.Operator) } if leftCol, ok := comp.Left.(*sqlparser.ColName); ok { w.Value.Left = FieldValue{ Type: "field", Field: &Field{ Name: string(leftCol.Name.String()), FromTable: string(leftCol.Qualifier.Name.String()), }, } } else if leftVal, ok := comp.Left.(*sqlparser.SQLVal); ok { w.Value.Left = FieldValue{ Type: "value", Value: string(leftVal.Val), } } else { return fmt.Errorf("unexpected left type: %T", comp.Left) } if rightCol, ok := comp.Right.(*sqlparser.ColName); ok { w.Value.Right = FieldValue{ Type: "field", Field: &Field{ Name: string(rightCol.Name.String()), FromTable: string(rightCol.Qualifier.Name.String()), }, } } else if rightVal, ok := comp.Right.(*sqlparser.SQLVal); ok { w.Value.Right = FieldValue{ Type: "value", Value: string(rightVal.Val), } } else { return fmt.Errorf("unexpected right type: %T", comp.Right) } default: return fmt.Errorf("unexpected type: %T", expr) } // log.Println("Where:", w) return nil } func (w *Where) FromCond(cond builder.Cond) error { if cond == nil { return fmt.Errorf("condition is nil") } switch builder.CondType(cond) { case "and": w.Cond = CondTypeAnd condAnd := builder.GetCondAnd(cond) for _, v := range condAnd { var sub Where sub.FromCond(v) w.Sub = append(w.Sub, sub) } case "or": w.Cond = CondTypeOr condOr := builder.GetCondOr(cond) for _, v := range condOr { var sub Where sub.FromCond(v) w.Sub = append(w.Sub, sub) } default: w.Cond = "condition" err := w.condition(cond) if err != nil { return err } } return nil } func (w *Where) condition(cond builder.Cond) error { condMap := map[string]interface{}{} json.Unmarshal([]byte(jsonEncode(cond)), &condMap) conditionName := "" switch cond.(type) { case builder.Eq: conditionName = "eq" case builder.Neq: conditionName = "neq" case builder.Gt: conditionName = "gt" case builder.Gte: conditionName = "gte" case builder.Lt: conditionName = "lt" case builder.Lte: conditionName = "lte" case builder.Like: conditionName = "like" case builder.Between: conditionName = "between" default: ct := builder.CondType(cond) if ct == "in" { conditionName = "in" } else if ct == "not in" { conditionName = "not in" } else { return fmt.Errorf("condition type %s not support", ct) } } if len(condMap) > 1 { w.Cond = CondTypeCondition for k, v := range condMap { var sub Where sub.Value = &WhereCondition{ Condition: conditionName, Left: FieldValue{ Type: "value", Value: k, }, Right: FieldValue{ Type: "value", Value: fmt.Sprint(v), }, } w.Sub = append(w.Sub, sub) } } else { w.Value = &WhereCondition{ Condition: conditionName, } for k, v := range condMap { w.Value.Left = FieldValue{ Type: "value", Value: k, } w.Value.Right = FieldValue{ Type: "value", Value: fmt.Sprint(v), } } } return nil } type Field struct { FromTable string `json:"from_table"` Name string `json:"name"` As string `json:"as"` } type From struct { Table string `json:"table"` SubQuery *Query `json:"sub_query,omitempty"` raw string `json:"-"` As string `json:"as"` } type Join struct { Type string `json:"type"` Query *Query `json:"query"` JoinWhere *Where `json:"join_where,omitempty"` } type Limit struct { Offset int `json:"offset"` Count int `json:"count"` } type OrderByField struct { Field Field `json:"field"` Direction string `json:"desc"` } type Query struct { Dialect DialectType `json:"dialect"` From From `json:"from"` Join []Join `json:"join,omitempty"` Fields []Field `json:"field,omitempty"` Where *Where `json:"where,omitempty"` GroupBy []Field `json:"group_by,omitempty"` Having string `json:"having,omitempty"` OrderBy []OrderByField `json:"order_by,omitempty"` Limit *Limit `json:"limit,omitempty"` } func getTheLeftmost(node *sqlparser.JoinTableExpr) (left *sqlparser.AliasedTableExpr) { if node.LeftExpr == nil { return nil } switch left := node.LeftExpr.(type) { case *sqlparser.AliasedTableExpr: return left case *sqlparser.JoinTableExpr: return getTheLeftmost(left) default: return nil } } func nodeExtractRawSQL(node sqlparser.SQLNode) string { buff := sqlparser.NewTrackedBuffer(nil) node.Format(buff) return buff.String() } func (q *Query) FromSelect(selectStmt *sqlparser.Select) error { if selectStmt == nil { return fmt.Errorf("selectStmt is nil") } if selectStmt.SelectExprs != nil { for _, v := range selectStmt.SelectExprs { switch expr := v.(type) { case *sqlparser.StarExpr: // log.Println("selectStmt.StarExpr", reflect.TypeOf(expr), jsonEncode(expr)) q.Fields = append(q.Fields, Field{ Name: "*", FromTable: expr.TableName.Name.String(), }) case *sqlparser.AliasedExpr: // log.Println("selectStmt.AliasedExpr", reflect.TypeOf(expr.Expr), jsonEncode(expr.Expr)) switch colName := expr.Expr.(type) { case *sqlparser.ColName: q.Fields = append(q.Fields, Field{ FromTable: colName.Qualifier.Name.String(), Name: string(colName.Name.String()), As: expr.As.String(), }) case *sqlparser.FuncExpr: name := nodeExtractRawSQL(colName) q.Fields = append(q.Fields, Field{ Name: name, As: expr.As.String(), }) } } } } if selectStmt.Where != nil { if q.Where == nil { q.Where = &Where{} } if err := q.Where.fromExpr(selectStmt.Where.Expr); err != nil { return err } } if selectStmt.From != nil { if len(selectStmt.From) > 1 { return fmt.Errorf("not support multi table") } form := selectStmt.From[0] switch form := form.(type) { case *sqlparser.AliasedTableExpr: //未联表 if form.As.String() != "" { q.From.As = form.As.String() } switch exp := form.Expr.(type) { case sqlparser.TableName: // log.Println("TableName:", jsonEncode(exp)) q.From.Table = string(exp.Name.String()) case *sqlparser.Subquery: subQuery := &Query{} if err := subQuery.FromSelect(exp.Select.(*sqlparser.Select)); err != nil { return err } q.From.SubQuery = subQuery } case *sqlparser.JoinTableExpr: //有联表 // log.Println("JoinTableExpr:", jsonEncode(form)) //确定主表 leftmost := getTheLeftmost(form) if leftmost != nil { // log.Println("leftmost:", jsonEncode(leftmost)) if leftmost.As.String() != "" { q.From.As = leftmost.As.String() } switch leftmostExpr := leftmost.Expr.(type) { case sqlparser.TableName: q.From.Table = string(leftmostExpr.Name.String()) case *sqlparser.Subquery: subQuery := &Query{} if err := subQuery.FromSelect(leftmostExpr.Select.(*sqlparser.Select)); err != nil { return err } q.From.SubQuery = subQuery } } else { //找不到最左边的表 return fmt.Errorf("can not find the leftmost table") } //确定联表 for { if form == nil { break } join := Join{ Type: strings.ToUpper(strings.Trim(strings.TrimSuffix(strings.Trim(strings.ToLower(form.Join), " "), "join"), " ")), } switch right := form.RightExpr.(type) { case *sqlparser.AliasedTableExpr: join.Query = &Query{} if right.As.String() != "" { join.Query.From.As = right.As.String() } switch rightExpr := right.Expr.(type) { case sqlparser.TableName: // log.Println("rightExpr:", jsonEncode(rightExpr)) join.Query.From.Table = string(rightExpr.Name.String()) case *sqlparser.Subquery: subQuery := &Query{} if err := subQuery.FromSelect(rightExpr.Select.(*sqlparser.Select)); err != nil { return err } join.Query.From.SubQuery = subQuery } } join.JoinWhere = &Where{} if err := join.JoinWhere.fromExpr(form.Condition.On); err != nil { return err } q.Join = append(q.Join, join) var ok bool form, ok = form.LeftExpr.(*sqlparser.JoinTableExpr) if !ok { break } } } // log.Println("selectStmt.From:", reflect.TypeOf(selectStmt.From), jsonEncode(selectStmt.From)) } if selectStmt.GroupBy != nil { // log.Println("selectStmt.GroupBy:", reflect.TypeOf(selectStmt.GroupBy), jsonEncode(selectStmt.GroupBy)) for _, v := range selectStmt.GroupBy { switch colName := v.(type) { case *sqlparser.ColName: q.GroupBy = append(q.GroupBy, Field{ FromTable: colName.Qualifier.Name.String(), Name: string(colName.Name.String()), }) } } } if selectStmt.OrderBy != nil { // log.Println("selectStmt.OrderBy:", reflect.TypeOf(selectStmt.OrderBy), jsonEncode(selectStmt.OrderBy)) for _, v := range selectStmt.OrderBy { switch colName := v.Expr.(type) { case *sqlparser.ColName: q.OrderBy = append(q.OrderBy, OrderByField{ Field: Field{ FromTable: colName.Qualifier.Name.String(), Name: string(colName.Name.String()), }, Direction: v.Direction, }) } } } if selectStmt.Having != nil { // log.Println("selectStmt.Having:", reflect.TypeOf(selectStmt.Having), jsonEncode(selectStmt.Having)) q.Having = nodeExtractRawSQL(selectStmt.Having.Expr) } if selectStmt.Limit != nil { // log.Println("selectStmt.Limit:", jsonEncode(selectStmt.Limit)) if q.Limit == nil { q.Limit = &Limit{} } switch offset := selectStmt.Limit.Offset.(type) { case *sqlparser.SQLVal: q.Limit.Offset = cast.ToInt(fmt.Sprintf("%s", offset.Val)) } switch count := selectStmt.Limit.Rowcount.(type) { case *sqlparser.SQLVal: q.Limit.Count = cast.ToInt(fmt.Sprintf("%s", count.Val)) } } // log.Println("query:", jsonEncode(q)) return nil } func (q *Query) FromSql(sql string) error { stmt, err := sqlparser.Parse(sql) if err != nil { return err } switch stmt.(type) { case *sqlparser.Select: selectStmt := stmt.(*sqlparser.Select) return q.FromSelect(selectStmt) default: return fmt.Errorf("unexpected type: %T", stmt) } } type DialectType string const ( POSTGRES DialectType = "postgres" SQLITE = "sqlite3" MYSQL = "mysql" MSSQL = "mssql" ORACLE = "oracle" UNION = "union" INTERSECT = "intersect" EXCEPT = "except" ) func (q *Query) ToBuilder() (*builder.Builder, error) { var dialect DialectType = MYSQL if q.Dialect != "" { dialect = q.Dialect } fields := []string{} for i := range q.Fields { name := q.Fields[i].Name if q.Fields[i].FromTable != "" { name = fmt.Sprintf("%s.`%s`", q.Fields[i].FromTable, name) } if q.Fields[i].As != "" { name = fmt.Sprintf("%s AS %s", name, q.Fields[i].As) } fields = append(fields, name) } b := builder.Dialect(string(dialect)).Select(fields...) if q.From.raw != "" { if q.From.As != "" { b = b.From(fmt.Sprintf("(%s)", q.From.raw), q.From.As) } else { b = b.From(fmt.Sprintf("(%s)", q.From.raw)) } } else if q.From.Table != "" { if q.From.As != "" { b = b.From(q.From.Table, q.From.As) } else { b = b.From(q.From.Table) } } else { if q.From.SubQuery != nil { subBuilder, err := q.From.SubQuery.ToBuilder() if err != nil { return nil, err } if q.From.As != "" { b = b.From(subBuilder, q.From.As) } else { b = b.From(subBuilder) } } } if q.Join != nil { for _, join := range q.Join { if join.Query != nil { // log.Println("join:", jsonEncode(join)) if join.Query.From.As == "" { return nil, fmt.Errorf("join table must have an alias") } if join.Query.From.raw != "" { cond, err := join.JoinWhere.ToCond() if err != nil { return nil, err } b = b.Join(join.Type, fmt.Sprintf("(%s)", join.Query.From.raw), cond, join.Query.From.As) } else if join.Query.From.SubQuery != nil { subBuilder, err := join.Query.From.SubQuery.ToBuilder() if err != nil { return nil, err } cond, err := join.JoinWhere.ToCond() if err != nil { return nil, err } b = b.Join(join.Type, subBuilder, cond, join.Query.From.As) } else { cond, err := join.JoinWhere.ToCond() if err != nil { return nil, err } b = b.Join(join.Type, join.Query.From.Table, cond, join.Query.From.As) } } } } if q.GroupBy != nil { groupBy := []string{} for i := range q.GroupBy { name := q.GroupBy[i].Name if q.GroupBy[i].FromTable != "" { name = fmt.Sprintf("%s.`%s`", q.GroupBy[i].FromTable, name) } groupBy = append(groupBy, name) } b = b.GroupBy(fmt.Sprintf("( %s )", strings.Join(groupBy, ", "))) } if q.OrderBy != nil { orderBy := []string{} for i := range q.OrderBy { name := q.OrderBy[i].Field.Name if q.OrderBy[i].Field.FromTable != "" { name = fmt.Sprintf("%s.`%s`", q.OrderBy[i].Field.FromTable, name) } if q.OrderBy[i].Direction != "" { name = fmt.Sprintf("%s %s", name, q.OrderBy[i].Direction) } orderBy = append(orderBy, name) } b = b.OrderBy(fmt.Sprintf("%s", strings.Join(orderBy, ", "))) } if q.Where != nil { cond, err := q.Where.ToCond() if err != nil { return nil, err } b = b.Where(cond) } if q.Having != "" { b = b.Having(q.Having) } if q.Limit != nil { b = b.Limit(q.Limit.Count, q.Limit.Offset) } return b, nil } func (q *Query) ToSql() (string, error) { b, err := q.ToBuilder() if err != nil { return "", fmt.Errorf("failed to build query: %w", err) } return b.ToBoundSQL() }