You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

757 lines
18 KiB

package vql
import (
"encoding/json"
"fmt"
"log"
"reflect"
"strings"
"git.ouxuan.net/3136352472/vql/builder"
"github.com/spf13/cast"
"github.com/xwb1989/sqlparser"
)
// CondType identifies the kind of node in a Where tree.
type CondType string

const (
	// CondTypeCondition marks a leaf node whose comparison lives in Where.Value.
	CondTypeCondition CondType = "condition"
	// CondTypeAnd marks a node whose Where.Sub children are combined with builder.And.
	CondTypeAnd CondType = "and"
	// CondTypeOr marks a node whose Where.Sub children are combined with builder.Or.
	CondTypeOr CondType = "or"
)
// FieldValue is one operand of a WhereCondition.
//
// When Type is "field", Field describes a column reference; ToCond treats
// every other Type as a literal and uses Value verbatim.
type FieldValue struct {
	Type  string `json:"type"` // field, value
	Field *Field `json:"field,omitempty"`
	Value string `json:"value,omitempty"`
}
// WhereCondition is a single binary comparison: Left <Condition> Right.
type WhereCondition struct {
	// Condition names the operator. ToCond understands eq/neq/gt/gte/lt/lte
	// (plus their long and symbolic spellings), like, not like, in and not in.
	Condition string     `json:"condition"`
	Left      FieldValue `json:"left"`
	Right     FieldValue `json:"right"`
}
// Where is a node of a WHERE / JOIN ... ON condition tree.
//
// Compound nodes (Cond "and" or "or") keep their operands in Sub; leaf nodes
// (Cond "condition") keep their comparison in Value.
type Where struct {
	Cond  CondType        `json:"cond"`
	Sub   []Where         `json:"sub,omitempty"`
	Value *WhereCondition `json:"value,omitempty"`
}
// ToCond converts the condition tree rooted at w into a builder.Cond.
//
// For a compound node (Cond "and" / "or") every Sub is converted recursively
// and the results are combined with builder.And / builder.Or. Any other
// non-leaf Cond value yields a nil builder.Cond and no error.
//
// For a leaf node (Cond "condition") the comparison in w.Value is converted.
// A "field" operand becomes a column name, prefixed with "table." when
// Field.FromTable is set; any other operand type uses Value verbatim. On the
// right-hand side a column is always wrapped in builder.Field, qualified or
// not, so that it is presumably rendered as a column reference rather than
// bound as a string parameter.
//
// It returns an error when a leaf has no Value, when a "field" operand has no
// Field, or when the operator is not recognised.
func (w *Where) ToCond() (builder.Cond, error) {
	if w.Cond != CondTypeCondition {
		var conds []builder.Cond
		for _, sub := range w.Sub {
			subCond, err := sub.ToCond()
			if err != nil {
				return nil, err
			}
			conds = append(conds, subCond)
		}
		switch w.Cond {
		case CondTypeAnd:
			return builder.And(conds...), nil
		case CondTypeOr:
			return builder.Or(conds...), nil
		}
		// Unknown compound kinds (including the zero value) have always
		// produced a nil condition without an error; keep that contract.
		return nil, nil
	}

	// Previously a leaf without a Value (or a "field" operand without a
	// Field) dereferenced nil and panicked.
	if w.Value == nil {
		return nil, fmt.Errorf("condition node has no value")
	}
	left, right := w.Value.Left, w.Value.Right

	var l string
	if left.Type == "field" {
		if left.Field == nil {
			return nil, fmt.Errorf("left operand is a field but has no field definition")
		}
		l = left.Field.Name
		if left.Field.FromTable != "" {
			l = left.Field.FromTable + "." + l
		}
	} else {
		l = left.Value
	}

	var r interface{}
	if right.Type == "field" {
		if right.Field == nil {
			return nil, fmt.Errorf("right operand is a field but has no field definition")
		}
		name := right.Field.Name
		if right.Field.FromTable != "" {
			name = right.Field.FromTable + "." + name
		}
		// An unqualified column used to be passed as a plain string and
		// therefore compared against the literal text of its own name.
		r = builder.Field(name)
	} else {
		r = right.Value
	}

	switch w.Value.Condition {
	case "eq", "equal", "=":
		return builder.Eq{l: r}, nil
	case "neq", "not equal", "!=":
		return builder.Neq{l: r}, nil
	case "gt", "greater than", ">":
		return builder.Gt{l: r}, nil
	case "gte", "greater than or equal", ">=":
		return builder.Gte{l: r}, nil
	case "lt", "less than", "<":
		return builder.Lt{l: r}, nil
	case "lte", "less than or equal", "<=":
		return builder.Lte{l: r}, nil
	case "like":
		return builder.Like{l, fmt.Sprint(r)}, nil
	case "not like":
		// builder.Like has no negated form; the " not" suffix on the column
		// turns "<col> LIKE ?" into "<col> not LIKE ?".
		return builder.Like{fmt.Sprint(l, " not"), fmt.Sprint(r)}, nil
	case "in":
		return builder.In(l, r), nil
	case "not in":
		return builder.NotIn(l, r), nil
	default:
		return nil, fmt.Errorf("unknown condition: %s", w.Value.Condition)
	}
}
// FromSql parses sql and, when it is a SELECT with a WHERE clause, loads that
// clause into w.
//
// Parse errors are returned unchanged. Statements other than SELECT, and
// SELECTs without a WHERE clause, leave w untouched and return nil.
func (w *Where) FromSql(sql string) error {
	stmt, err := sqlparser.Parse(sql)
	if err != nil {
		return err
	}
	sel, isSelect := stmt.(*sqlparser.Select)
	if !isSelect || sel.Where == nil {
		return nil
	}
	return w.fromExpr(sel.Where.Expr)
}
// fromExpr loads a parsed WHERE / ON expression into w.
//
// AND and OR nodes turn w into a compound node with two Sub entries (left
// operand first). Parentheses are unwrapped. A comparison turns w into a leaf
// via conditionExpr. Any other expression type is rejected with an error.
func (w *Where) fromExpr(expr sqlparser.Expr) error {
	switch node := expr.(type) {
	case *sqlparser.AndExpr:
		return w.fromBinaryExpr(CondTypeAnd, node.Left, node.Right)
	case *sqlparser.OrExpr:
		return w.fromBinaryExpr(CondTypeOr, node.Left, node.Right)
	case *sqlparser.ParenExpr:
		return w.fromExpr(node.Expr)
	case *sqlparser.ComparisonExpr:
		w.Cond = CondTypeCondition
		return w.conditionExpr(node)
	default:
		return fmt.Errorf("unknown expr: %s", reflect.TypeOf(expr))
		// NOTE(review): unreachable (go vet reports it). It is currently the
		// only use of the "log" import, so remove it together with that import.
		log.Println("unknown expr:", reflect.TypeOf(expr), jsonEncode(expr))
	}
	// Only reachable in the compiler's eyes because of the dead statement above.
	return nil
}

// fromBinaryExpr turns w into a compound node of the given kind whose two Sub
// entries are converted from lhs and rhs, in that order. A conversion error
// is returned as soon as it happens; operands converted before it stay in w.Sub.
func (w *Where) fromBinaryExpr(kind CondType, lhs, rhs sqlparser.Expr) error {
	w.Cond = kind
	for _, operand := range [...]sqlparser.Expr{lhs, rhs} {
		var sub Where
		if err := sub.fromExpr(operand); err != nil {
			return err
		}
		w.Sub = append(w.Sub, sub)
	}
	return nil
}
// conditionExpr stores a single comparison (for example `t.id = 3` or
// `name like 'a%'`) in w.Value.
//
// The SQL operator is mapped to the condition names ToCond understands. Each
// operand must be a column reference, stored as a "field", or a literal,
// stored as a "value" holding the literal's raw text. A non-comparison expr,
// an unsupported operator or any other operand kind is reported as an error.
func (w *Where) conditionExpr(expr sqlparser.Expr) error {
	comp, ok := expr.(*sqlparser.ComparisonExpr)
	if !ok {
		return fmt.Errorf("unexpected type: %T", expr)
	}

	w.Value = &WhereCondition{Condition: "unknown"}
	switch op := comp.Operator; op {
	case "=":
		w.Value.Condition = "eq"
	case "!=":
		w.Value.Condition = "neq"
	case ">":
		w.Value.Condition = "gt"
	case ">=":
		w.Value.Condition = "gte"
	case "<":
		w.Value.Condition = "lt"
	case "<=":
		w.Value.Condition = "lte"
	case "like", "not like", "in", "not in":
		// These operators already use the same spelling as ToCond.
		w.Value.Condition = op
	default:
		return fmt.Errorf("unknown operator: %s", op)
	}

	left, ok := comparisonOperand(comp.Left)
	if !ok {
		return fmt.Errorf("unexpected left type: %T", comp.Left)
	}
	w.Value.Left = left

	right, ok := comparisonOperand(comp.Right)
	if !ok {
		return fmt.Errorf("unexpected right type: %T", comp.Right)
	}
	w.Value.Right = right
	return nil
}

// comparisonOperand converts one side of a comparison into a FieldValue. It
// reports false for anything other than a column name or a literal value.
func comparisonOperand(expr sqlparser.Expr) (FieldValue, bool) {
	switch operand := expr.(type) {
	case *sqlparser.ColName:
		return FieldValue{
			Type: "field",
			Field: &Field{
				Name:      operand.Name.String(),
				FromTable: operand.Qualifier.Name.String(),
			},
		}, true
	case *sqlparser.SQLVal:
		return FieldValue{Type: "value", Value: string(operand.Val)}, true
	}
	return FieldValue{}, false
}
// FromCond loads a builder condition tree into w.
//
// And / Or conditions become compound nodes with one Sub per child, converted
// recursively. Any other condition becomes a leaf through w.condition.
//
// It returns an error for a nil cond, for an unsupported leaf type, or when
// any child fails to convert. The error is prefixed with the compound kinds
// on the way down, e.g. "and: or: ...".
func (w *Where) FromCond(cond builder.Cond) error {
	if cond == nil {
		return fmt.Errorf("condition is nil")
	}

	// Child errors used to be discarded, which left half-built leaves
	// (Cond "condition", nil Value) in the tree and made ToCond panic later.
	addSub := func(c builder.Cond) error {
		var sub Where
		if err := sub.FromCond(c); err != nil {
			return fmt.Errorf("%s: %w", w.Cond, err)
		}
		w.Sub = append(w.Sub, sub)
		return nil
	}

	switch builder.CondType(cond) {
	case "and":
		w.Cond = CondTypeAnd
		for _, c := range builder.GetCondAnd(cond) {
			if err := addSub(c); err != nil {
				return err
			}
		}
	case "or":
		w.Cond = CondTypeOr
		for _, c := range builder.GetCondOr(cond) {
			if err := addSub(c); err != nil {
				return err
			}
		}
	default:
		w.Cond = CondTypeCondition
		return w.condition(cond)
	}
	return nil
}
// condition converts a single leaf builder.Cond (Eq, Neq, Gt, Gte, Lt, Lte,
// Like, Between, In or NotIn) into w. Every operand is recorded with Type
// "value".
//
// Like is read directly as its [column, pattern] pair. The " not" column
// suffix that ToCond emits for "not like" is mapped back to "not like".
//
// The other kinds are decoded from their JSON form (via jsonEncode) into a
// key -> value map:
//   - no keys: w.Value holds only the condition name;
//   - one key: it becomes w.Value;
//   - several keys (e.g. Eq{"a": 1, "b": 2}): w becomes an AND node with one
//     leaf per key, in sorted key order. Map conditions are rendered as an AND
//     of their entries by xorm-style builders; confirm for this builder fork.
//
// NOTE(review): In/NotIn and Between do not round-trip. In/NotIn presumably
// encode to JSON without their column/values, giving an empty leaf. Between
// decodes to its struct fields, and ToCond does not accept "between".
//
// It returns an error for an unsupported condition type or when the JSON
// form cannot be decoded into a map.
func (w *Where) condition(cond builder.Cond) error {
	// Like is a two-element array rather than a map. It used to go through the
	// map decoding below, which failed silently and produced an empty leaf.
	if like, ok := cond.(builder.Like); ok {
		name, column := "like", like[0]
		if strings.HasSuffix(column, " not") {
			name, column = "not like", strings.TrimSuffix(column, " not")
		}
		w.Cond = CondTypeCondition
		w.Value = &WhereCondition{
			Condition: name,
			Left:      FieldValue{Type: "value", Value: column},
			Right:     FieldValue{Type: "value", Value: like[1]},
		}
		return nil
	}

	conditionName := ""
	switch cond.(type) {
	case builder.Eq:
		conditionName = "eq"
	case builder.Neq:
		conditionName = "neq"
	case builder.Gt:
		conditionName = "gt"
	case builder.Gte:
		conditionName = "gte"
	case builder.Lt:
		conditionName = "lt"
	case builder.Lte:
		conditionName = "lte"
	case builder.Between:
		conditionName = "between"
	default:
		ct := builder.CondType(cond)
		switch {
		case ct == "in":
			conditionName = "in"
		case ct == "not in":
			conditionName = "not in"
		default:
			return fmt.Errorf("condition type %s not support", ct)
		}
	}

	condMap := map[string]interface{}{}
	dec := json.NewDecoder(strings.NewReader(string(jsonEncode(cond))))
	// Keep numbers as their literal text: plain decoding turns them into
	// float64, and fmt.Sprint(float64(1000000)) yields "1e+06".
	dec.UseNumber()
	if err := dec.Decode(&condMap); err != nil {
		return fmt.Errorf("decoding %s condition: %w", conditionName, err)
	}

	keys := make([]string, 0, len(condMap))
	for k := range condMap {
		keys = append(keys, k)
	}
	// Map iteration order is random. Sort the keys so the tree, and the SQL
	// later rendered from it, is deterministic. Condition maps are expected to
	// be small, so an insertion sort is enough.
	for i := 1; i < len(keys); i++ {
		for j := i; j > 0 && keys[j] < keys[j-1]; j-- {
			keys[j], keys[j-1] = keys[j-1], keys[j]
		}
	}

	leaf := func(key string) *WhereCondition {
		return &WhereCondition{
			Condition: conditionName,
			Left:      FieldValue{Type: "value", Value: key},
			Right:     FieldValue{Type: "value", Value: fmt.Sprint(condMap[key])},
		}
	}

	if len(keys) > 1 {
		// This used to produce a "condition" node with a nil Value and subs
		// without a Cond, which ToCond could not convert (nil dereference).
		w.Cond = CondTypeAnd
		for _, k := range keys {
			w.Sub = append(w.Sub, Where{Cond: CondTypeCondition, Value: leaf(k)})
		}
		return nil
	}

	w.Value = &WhereCondition{Condition: conditionName}
	if len(keys) == 1 {
		w.Value = leaf(keys[0])
	}
	return nil
}
// Field is an entry of a select list, GROUP BY or ORDER BY.
type Field struct {
	// FromTable is the optional table qualifier (alias or table name).
	FromTable string `json:"from_table"`
	// Name is a column name, "*" for a wildcard, or raw SQL text for an
	// expression such as a function call.
	Name string `json:"name"`
	// As is the optional output alias (select list only).
	As string `json:"as"`
}
// From describes a query source: a table, a sub-query or raw SQL, with an
// optional alias.
type From struct {
	Table    string `json:"table"`
	SubQuery *Query `json:"sub_query,omitempty"`
	// raw, when non-empty, is SQL text that ToBuilder wraps in parentheses and
	// uses in preference to Table and SubQuery. Being unexported, it can only
	// be set inside this package and is never (de)serialised.
	raw string `json:"-"`
	As  string `json:"as"`
}
// Join is one JOIN clause.
type Join struct {
	// Type is the join keyword without "JOIN", upper-cased (e.g. "LEFT").
	Type string `json:"type"`
	// Query.From names the joined table or sub-query. ToBuilder requires it to
	// have an alias and ignores the other Query fields.
	Query *Query `json:"query"`
	// JoinWhere is the ON condition.
	JoinWhere *Where `json:"join_where,omitempty"`
}
// Limit is a LIMIT clause; ToBuilder passes it as builder.Limit(Count, Offset).
type Limit struct {
	Offset int `json:"offset"`
	Count  int `json:"count"`
}
// OrderByField is one ORDER BY term.
type OrderByField struct {
	Field Field `json:"field"`
	// Direction is appended verbatim after the column (e.g. "asc" or "desc").
	// NOTE(review): the JSON key is "desc" although the value is a direction;
	// renaming it would break existing serialised queries.
	Direction string `json:"desc"`
}
// Query is a serialisable description of a SELECT statement. It can be filled
// from SQL (FromSql / FromSelect) and rendered through the builder package
// (ToBuilder / ToSql).
type Query struct {
	// Dialect is passed to builder.Dialect; ToBuilder uses MYSQL when empty.
	Dialect DialectType `json:"dialect"`
	From    From        `json:"from"`
	Join    []Join      `json:"join,omitempty"`
	Fields  []Field     `json:"field,omitempty"`
	Where   *Where      `json:"where,omitempty"`
	GroupBy []Field     `json:"group_by,omitempty"`
	// Having is raw SQL text for the HAVING clause.
	Having  string         `json:"having,omitempty"`
	OrderBy []OrderByField `json:"order_by,omitempty"`
	Limit   *Limit         `json:"limit,omitempty"`
}
// getTheLeftmost follows LeftExpr links down a chain of nested joins and
// returns the first (leftmost) table or sub-query. It returns nil when the
// chain ends in a nil LeftExpr or in any other kind of table expression.
func getTheLeftmost(node *sqlparser.JoinTableExpr) (left *sqlparser.AliasedTableExpr) {
	for cur := node; cur.LeftExpr != nil; {
		switch next := cur.LeftExpr.(type) {
		case *sqlparser.AliasedTableExpr:
			return next
		case *sqlparser.JoinTableExpr:
			cur = next
		default:
			return nil
		}
	}
	return nil
}
// nodeExtractRawSQL formats node back into SQL text with sqlparser's
// formatter. node must be non-nil, since its Format method is called.
func nodeExtractRawSQL(node sqlparser.SQLNode) string {
	out := sqlparser.NewTrackedBuffer(nil)
	node.Format(out)
	return out.String()
}
// FromSelect populates q from a parsed SELECT statement.
//
// The select list, WHERE, FROM (a single table, a sub-query or a join chain),
// GROUP BY, ORDER BY, HAVING and LIMIT clauses are copied into q. Column
// references become structured Fields. Any other expression in the select
// list, GROUP BY or ORDER BY is stored as raw SQL text in Field.Name. Joins
// are recorded in q.Join in the order they appear in the SQL.
//
// It returns an error for a nil statement, for more than one comma-separated
// FROM table, for a sub-query that is not a plain SELECT (e.g. a UNION), or
// when a WHERE / ON expression cannot be converted.
func (q *Query) FromSelect(selectStmt *sqlparser.Select) error {
	if selectStmt == nil {
		return fmt.Errorf("selectStmt is nil")
	}
	q.collectSelectFields(selectStmt.SelectExprs)
	if selectStmt.Where != nil {
		if q.Where == nil {
			q.Where = &Where{}
		}
		if err := q.Where.fromExpr(selectStmt.Where.Expr); err != nil {
			return err
		}
	}
	if err := q.collectFromClause(selectStmt.From); err != nil {
		return err
	}
	for _, expr := range selectStmt.GroupBy {
		q.GroupBy = append(q.GroupBy, sqlExprToField(expr))
	}
	for _, order := range selectStmt.OrderBy {
		q.OrderBy = append(q.OrderBy, OrderByField{
			Field:     sqlExprToField(order.Expr),
			Direction: order.Direction,
		})
	}
	if selectStmt.Having != nil {
		q.Having = nodeExtractRawSQL(selectStmt.Having.Expr)
	}
	if limit := selectStmt.Limit; limit != nil {
		if q.Limit == nil {
			q.Limit = &Limit{}
		}
		// Only *sqlparser.SQLVal operands are read; anything else, including
		// a missing OFFSET, leaves the corresponding field as it was.
		if offset, ok := limit.Offset.(*sqlparser.SQLVal); ok {
			q.Limit.Offset = cast.ToInt(string(offset.Val))
		}
		if count, ok := limit.Rowcount.(*sqlparser.SQLVal); ok {
			q.Limit.Count = cast.ToInt(string(count.Val))
		}
	}
	return nil
}

// collectSelectFields appends one Field per entry of the select list.
func (q *Query) collectSelectFields(exprs sqlparser.SelectExprs) {
	for _, selectExpr := range exprs {
		switch expr := selectExpr.(type) {
		case *sqlparser.StarExpr:
			q.Fields = append(q.Fields, Field{
				Name:      "*",
				FromTable: expr.TableName.Name.String(),
			})
		case *sqlparser.AliasedExpr:
			field := sqlExprToField(expr.Expr)
			field.As = expr.As.String()
			q.Fields = append(q.Fields, field)
		}
	}
}

// sqlExprToField converts expr into a Field. A column reference keeps its
// table qualifier and name; any other expression is kept as raw SQL in Name
// (non-column expressions used to be dropped silently).
func sqlExprToField(expr sqlparser.Expr) Field {
	if col, ok := expr.(*sqlparser.ColName); ok {
		return Field{
			FromTable: col.Qualifier.Name.String(),
			Name:      col.Name.String(),
		}
	}
	return Field{Name: nodeExtractRawSQL(expr)}
}

// collectFromClause fills q.From, and q.Join for a join chain, from the FROM clause.
func (q *Query) collectFromClause(tables sqlparser.TableExprs) error {
	if len(tables) == 0 {
		return nil
	}
	if len(tables) > 1 {
		return fmt.Errorf("not support multi table")
	}
	switch table := tables[0].(type) {
	case *sqlparser.AliasedTableExpr: // a single table or sub-query, no join
		return fillFromWithAliasedTable(&q.From, table)
	case *sqlparser.JoinTableExpr: // a join chain; its leftmost table is the main table
		leftmost := getTheLeftmost(table)
		if leftmost == nil {
			return fmt.Errorf("can not find the leftmost table")
		}
		if err := fillFromWithAliasedTable(&q.From, leftmost); err != nil {
			return err
		}
		return q.collectJoins(table)
	}
	return nil
}

// fillFromWithAliasedTable copies the alias (when set) and the table name or
// sub-query of t into dst.
func fillFromWithAliasedTable(dst *From, t *sqlparser.AliasedTableExpr) error {
	if as := t.As.String(); as != "" {
		dst.As = as
	}
	switch expr := t.Expr.(type) {
	case sqlparser.TableName:
		dst.Table = expr.Name.String()
	case *sqlparser.Subquery:
		// A UNION sub-query is not a *Select; the unchecked assertion panicked.
		sel, ok := expr.Select.(*sqlparser.Select)
		if !ok {
			return fmt.Errorf("unsupported subquery type: %T", expr.Select)
		}
		sub := &Query{}
		if err := sub.FromSelect(sel); err != nil {
			return err
		}
		dst.SubQuery = sub
	}
	return nil
}

// collectJoins appends one Join per JoinTableExpr in the chain rooted at node.
func (q *Query) collectJoins(node *sqlparser.JoinTableExpr) error {
	var joins []Join
	for node != nil {
		join := Join{
			Type: strings.ToUpper(strings.Trim(strings.TrimSuffix(strings.Trim(strings.ToLower(node.Join), " "), "join"), " ")),
		}
		if right, ok := node.RightExpr.(*sqlparser.AliasedTableExpr); ok {
			join.Query = &Query{}
			if err := fillFromWithAliasedTable(&join.Query.From, right); err != nil {
				return err
			}
		}
		join.JoinWhere = &Where{}
		if err := join.JoinWhere.fromExpr(node.Condition.On); err != nil {
			return err
		}
		joins = append(joins, join)
		next, _ := node.LeftExpr.(*sqlparser.JoinTableExpr)
		node = next
	}
	// Joins nest through LeftExpr, with the main table at the bottom, so the
	// walk above meets them last-to-first. Append them in SQL order, so that
	// an ON clause can still refer to a table joined before it.
	for i := len(joins) - 1; i >= 0; i-- {
		q.Join = append(q.Join, joins[i])
	}
	return nil
}
// FromSql parses sql and populates q through FromSelect.
//
// Parse errors are returned unchanged. Any statement other than a SELECT is
// rejected with an error naming its type.
func (q *Query) FromSql(sql string) error {
	stmt, err := sqlparser.Parse(sql)
	if err != nil {
		return err
	}
	if sel, isSelect := stmt.(*sqlparser.Select); isSelect {
		return q.FromSelect(sel)
	}
	return fmt.Errorf("unexpected type: %T", stmt)
}
// DialectType names the SQL dialect handed to builder.Dialect.
type DialectType string

// Dialect names understood by ToBuilder.
//
// NOTE(review): only POSTGRES has type DialectType. The others are untyped
// string constants, because a const spec with an explicit value does not
// inherit the previous spec's type. They still convert implicitly where a
// DialectType is expected (e.g. `var dialect DialectType = MYSQL`). UNION,
// INTERSECT and EXCEPT look like set-operation keywords rather than
// dialects; confirm the intended use before typing them.
const (
	POSTGRES  DialectType = "postgres"
	SQLITE                = "sqlite3"
	MYSQL                 = "mysql"
	MSSQL                 = "mssql"
	ORACLE                = "oracle"
	UNION                 = "union"
	INTERSECT             = "intersect"
	EXCEPT                = "except"
)
// ToBuilder renders q as a *builder.Builder for q.Dialect (MYSQL when empty).
//
// Table-qualified columns in the select list, GROUP BY and ORDER BY are
// written as table.`name`, and a qualified wildcard as table.*. Unqualified
// names and raw expressions are written verbatim. Joins with a nil Query are
// skipped; every other join must have an alias and a JoinWhere. Errors from
// sub-query and condition conversion are returned unchanged.
func (q *Query) ToBuilder() (*builder.Builder, error) {
	var dialect DialectType = MYSQL
	if q.Dialect != "" {
		dialect = q.Dialect
	}

	fields := make([]string, 0, len(q.Fields))
	for _, f := range q.Fields {
		name := qualifiedColumnSQL(f.FromTable, f.Name)
		if f.As != "" {
			name = fmt.Sprintf("%s AS %s", name, f.As)
		}
		fields = append(fields, name)
	}
	b := builder.Dialect(string(dialect)).Select(fields...)

	switch {
	case q.From.raw != "":
		if q.From.As != "" {
			b = b.From(fmt.Sprintf("(%s)", q.From.raw), q.From.As)
		} else {
			b = b.From(fmt.Sprintf("(%s)", q.From.raw))
		}
	case q.From.Table != "":
		if q.From.As != "" {
			b = b.From(q.From.Table, q.From.As)
		} else {
			b = b.From(q.From.Table)
		}
	case q.From.SubQuery != nil:
		subBuilder, err := q.From.SubQuery.ToBuilder()
		if err != nil {
			return nil, err
		}
		if q.From.As != "" {
			b = b.From(subBuilder, q.From.As)
		} else {
			b = b.From(subBuilder)
		}
	}

	for _, join := range q.Join {
		if join.Query == nil {
			continue
		}
		from := join.Query.From
		if from.As == "" {
			return nil, fmt.Errorf("join table must have an alias")
		}
		// A missing ON condition used to panic inside ToCond.
		if join.JoinWhere == nil {
			return nil, fmt.Errorf("join table %s has no join condition", from.As)
		}
		var table interface{}
		switch {
		case from.raw != "":
			table = fmt.Sprintf("(%s)", from.raw)
		case from.SubQuery != nil:
			subBuilder, err := from.SubQuery.ToBuilder()
			if err != nil {
				return nil, err
			}
			table = subBuilder
		default:
			table = from.Table
		}
		cond, err := join.JoinWhere.ToCond()
		if err != nil {
			return nil, err
		}
		b = b.Join(join.Type, table, cond, from.As)
	}

	// len checks (not nil checks): an empty decoded slice would otherwise
	// emit an empty GROUP BY / ORDER BY clause.
	if len(q.GroupBy) > 0 {
		groupBy := make([]string, 0, len(q.GroupBy))
		for _, f := range q.GroupBy {
			groupBy = append(groupBy, qualifiedColumnSQL(f.FromTable, f.Name))
		}
		b = b.GroupBy(fmt.Sprintf("( %s )", strings.Join(groupBy, ", ")))
	}
	if len(q.OrderBy) > 0 {
		orderBy := make([]string, 0, len(q.OrderBy))
		for _, o := range q.OrderBy {
			term := qualifiedColumnSQL(o.Field.FromTable, o.Field.Name)
			if o.Direction != "" {
				term = fmt.Sprintf("%s %s", term, o.Direction)
			}
			orderBy = append(orderBy, term)
		}
		b = b.OrderBy(strings.Join(orderBy, ", "))
	}

	if q.Where != nil {
		cond, err := q.Where.ToCond()
		if err != nil {
			return nil, err
		}
		b = b.Where(cond)
	}
	if q.Having != "" {
		b = b.Having(q.Having)
	}
	if q.Limit != nil {
		b = b.Limit(q.Limit.Count, q.Limit.Offset)
	}
	return b, nil
}

// qualifiedColumnSQL renders name, prefixed by table when table is non-empty.
// A qualified column is backtick-quoted (table.`name`), except the "*"
// wildcard, which must stay bare: table.`*` names a column literally called "*".
func qualifiedColumnSQL(table, name string) string {
	switch {
	case table == "":
		return name
	case name == "*":
		return table + ".*"
	default:
		return fmt.Sprintf("%s.`%s`", table, name)
	}
}
// ToSql builds q with ToBuilder and returns the result of the builder's
// ToBoundSQL. A ToBuilder failure is wrapped with "failed to build query".
func (q *Query) ToSql() (string, error) {
	built, buildErr := q.ToBuilder()
	if buildErr == nil {
		return built.ToBoundSQL()
	}
	return "", fmt.Errorf("failed to build query: %w", buildErr)
}