diff --git a/apijson/query_context.go b/apijson/query_context.go new file mode 100644 index 0000000..df57dce --- /dev/null +++ b/apijson/query_context.go @@ -0,0 +1,113 @@ +package apijson + +import ( + "fmt" + "net/http" + "strings" + + "github.com/micro/go-micro/logger" +) + +type QueryContext struct { + req map[string]interface{} + code int + nodeTree map[string]*QueryNode + nodePathMap map[string]*QueryNode + err error + explain bool +} + +func NewQueryContext(req map[string]interface{}) *QueryContext { + return &QueryContext{ + req: req, + code: http.StatusOK, + nodeTree: make(map[string]*QueryNode), + nodePathMap: make(map[string]*QueryNode), + } +} + +func (c *QueryContext) Response() map[string]interface{} { + c.doParse() + if c.err == nil { + c.doQuery() + } + resultMap := make(map[string]interface{}) + resultMap["ok"] = c.code == http.StatusOK + resultMap["code"] = c.code + if c.err != nil { + resultMap["msg"] = c.err.Error() + } else { + for k, v := range c.nodeTree { + //logger.Debugf("response.nodeMap K: %s, V: %v", k, v) + resultMap[k] = v.Result() + } + } + return resultMap +} + +func (c *QueryContext) doParse() { + //startTime := time.Now().Nanosecond() + for key := range c.req { + if c.err != nil { + return + } + if key == "@explain" { + c.explain = c.req[key].(bool) + } else if c.nodeTree[key] == nil { + c.parseByKey(key) + } + } +} + +func (c *QueryContext) doQuery() { + for _, n := range c.nodeTree { + if c.err != nil { + return + } + n.doQueryData() + } +} + +func (c *QueryContext) parseByKey(key string) { + queryObject := c.req[key] + if queryObject == nil { + c.err = fmt.Errorf("值不能为空, key: %s, value: %v", key, queryObject) + return + } + if queryMap, ok := queryObject.(map[string]interface{}); !ok { + c.err = fmt.Errorf("值类型不对, key: %s, value: %v", key, queryObject) + } else { + node := NewQueryNode(c, key, key, queryMap) + logger.Debugf("parse %s: %+v", key, node) + c.nodeTree[key] = node + } +} + +func (c *QueryContext) End(code int, msg string) { + c.code = code + logger.Errorf("发生错误,终止处理, code: %d, msg: %s", code, msg) +} + +func (c *QueryContext) findResult(value string) interface{} { + i := strings.LastIndex(value, "/") + path := value[0:i] + node := c.nodePathMap[path] + if node == nil { + c.err = fmt.Errorf("关联查询参数有误: %s", value) + return nil + } + if node.running { + c.err = fmt.Errorf("有循环依赖") + return nil + } + node.doQueryData() + if c.err != nil { + return nil + } + if node.CurrentData == nil { + logger.Info("查询结果为空,queryPath: " + value) + return nil + } + key := value[i+1:] + return node.CurrentData[key] +} diff --git a/apijson/query_node.go b/apijson/query_node.go new file mode 100644 index 0000000..be1065c --- /dev/null +++ b/apijson/query_node.go @@ -0,0 +1,198 @@ +package apijson + +import ( + "fmt" + "net/http" + "strings" + "time" + + "github.com/micro/go-micro/logger" +) + +func NewQueryNode(c *QueryContext, path, key string, queryMap map[string]interface{}) *QueryNode { + n := &QueryNode{ + ctx: c, + Key: strings.ToLower(key), + Path: path, + RequestMap: queryMap, + start: time.Now().UnixNano(), + sqlExecutor: &MysqlExecutor{}, + isList: strings.HasSuffix(key, "[]"), + } + c.nodePathMap[path] = n + if n.isList { + n.parseList() + } else { + n.parseOne() + } + return n +} + +type QueryNode struct { + ctx *QueryContext + start int64 + depth int8 + running bool + completed bool + isList bool + page interface{} + count interface{} + + sqlExecutor *MysqlExecutor + primaryKey string + relateKV map[string]string + + Key string + Path string + RequestMap map[string]interface{} + CurrentData map[string]interface{} + ResultList []map[string]interface{} + children map[string]*QueryNode +} + +func (n *QueryNode) parseList() { + root := n.ctx + if root.err != nil { + return + } + if value, exists := n.RequestMap[n.Key[0:len(n.Key)-2]]; exists { + if kvs, ok := value.(map[string]interface{}); ok { + root.err = n.sqlExecutor.ParseTable(n.Key) + n.parseKVs(kvs) + } else { + root.err = fmt.Errorf("列表同名参数展开出错,listKey: %s, object: %v", n.Key, value) + root.code = http.StatusBadRequest + } + return + } + for field, value := range n.RequestMap { + if value == nil { + root.err = fmt.Errorf("field of [%s] value error, %s is nil", n.Key, field) + return + } + switch field { + case "page": + n.page = value + case "count": + n.count = value + default: + if kvs, ok := value.(map[string]interface{}); ok { + child := NewQueryNode(root, n.Path+"/"+field, field, kvs) + if root.err != nil { + return + } + if n.children == nil { + n.children = make(map[string]*QueryNode) + } + n.children[field] = child + if nonDepend(n, child) && len(n.primaryKey) == 0 { + n.primaryKey = field + } + } + } + } +} + +func nonDepend(parent, child *QueryNode) bool { + if len(child.relateKV) == 0 { + return true + } + for _, v := range child.relateKV { + if strings.HasPrefix(v, parent.Path) { + return false + } + } + return true +} + +func (n *QueryNode) parseOne() { + root := n.ctx + root.err = n.sqlExecutor.ParseTable(n.Key) + if root.err != nil { + root.code = http.StatusBadRequest + return + } + n.sqlExecutor.PageSize(0, 1) + n.parseKVs(n.RequestMap) +} + +func (n *QueryNode) parseKVs(kvs map[string]interface{}) { + root := n.ctx + for field, value := range kvs { + logger.Debugf("%s -> parse %s %v", n.Key, field, value) + if value == nil { + root.err = fmt.Errorf("field value error, %s is nil", field) + root.code = http.StatusBadRequest + return + } + if queryPath, ok := value.(string); ok && strings.HasSuffix(field, "@") { // @ 结尾表示有关联查询 + if n.relateKV == nil { + n.relateKV = make(map[string]string) + } + fullPath := queryPath + if strings.HasPrefix(queryPath, "/") { + fullPath = n.Path + queryPath + } + n.relateKV[field[0:len(field)-1]] = fullPath + } else { + n.sqlExecutor.ParseCondition(field, value) + } + } +} + +func (n *QueryNode) Result() interface{} { + if n.isList { + return n.ResultList + } + if len(n.ResultList) > 0 { + return n.ResultList[0] + } + return nil +} + +func (n *QueryNode) doQueryData() { + if n.completed { + return + } + n.running = true + defer func() { n.running, n.completed = false, true }() + root := n.ctx + if len(n.relateKV) > 0 { + for field, queryPath := range n.relateKV { + value := root.findResult(queryPath) + if root.err != nil { + return + } + n.sqlExecutor.ParseCondition(field, value) + } + } + if !n.isList { + n.ResultList, root.err = n.sqlExecutor.Exec() + if len(n.ResultList) > 0 { + n.CurrentData = n.ResultList[0] + return + } + return + } + primary := n.children[n.primaryKey] + primary.sqlExecutor.PageSize(n.page, n.count) + primary.doQueryData() + if root.err != nil { + return + } + listData := primary.ResultList + n.ResultList = make([]map[string]interface{}, len(listData)) + for i, x := range listData { + n.ResultList[i] = make(map[string]interface{}) + n.ResultList[i][n.primaryKey] = x + primary.CurrentData = x + if len(n.children) > 0 { + for _, child := range n.children { + if child != primary { + child.doQueryData() + n.ResultList[i][child.Key] = child.Result() + } + } + } + } +} diff --git a/apijson/sqlparser.go b/apijson/sqlparser.go new file mode 100644 index 0000000..ad2ba96 --- /dev/null +++ b/apijson/sqlparser.go @@ -0,0 +1,118 @@ +package apijson + +import ( + "bytes" + "strconv" + "strings" +) + +const DefaultLimit = 1000 + +type MysqlExecutor struct { + table string + columns []string + where []string + params []interface{} + order string + group string + limit int + page int +} + +func (e *MysqlExecutor) Table() string { + return e.table +} + +func (e *MysqlExecutor) ParseTable(t string) error { + if strings.HasSuffix(t, "[]") { + t = t[0 : len(t)-2] + } + e.table = t + return nil +} + +func (e *MysqlExecutor) ToSQL() string { + var buf bytes.Buffer + buf.WriteString("SELECT ") + if e.columns == nil { + buf.WriteString("*") + } else { + buf.WriteString(strings.Join(e.columns, ",")) + } + buf.WriteString(" FROM ") + buf.WriteString(e.table) + if len(e.where) > 0 { + buf.WriteString(" WHERE ") + buf.WriteString(strings.Join(e.where, " and ")) + } + if e.order != "" { + buf.WriteString(" ORDER BY ") + buf.WriteString(e.order) + } + buf.WriteString(" LIMIT ") + buf.WriteString(strconv.Itoa(e.limit)) + if e.limit > 1 { + buf.WriteString(" OFFSET ") + buf.WriteString(strconv.Itoa(e.limit * e.page)) + } + return buf.String() +} + +func (e *MysqlExecutor) ParseCondition(field string, value interface{}) { + if values, ok := value.([]interface{}); ok { + // 数组使用 IN 条件 + condition := field + " in (" + for i, v := range values { + if i == 0 { + condition += "?" + } else { + condition += ",?" + } + e.params = append(e.params, v) + } + e.where = append(e.where, condition+")") + } else if valueStr, ok := value.(string); ok { + if strings.HasPrefix(field, "@") { + switch field[1:] { + case "order": + e.order = valueStr + case "column": + e.columns = strings.Split(valueStr, ",") + } + } else { + e.where = append(e.where, field+"=?") + e.params = append(e.params, valueStr) + } + } else { + e.where = append(e.where, field+"=?") + e.params = append(e.params, value) + } +} + +func (e *MysqlExecutor) Exec() ([]map[string]interface{}, error) { + sql := e.ToSQL() + return QueryAll(sql, e.params...) +} + +var QueryAll = func(sql string, args ...interface{}) ([]map[string]interface{}, error) { + return nil, nil +} + +func SetQueryAll(f func(sql string, args ...interface{}) ([]map[string]interface{}, error)) { + QueryAll = f +} + +func (e *MysqlExecutor) PageSize(page interface{}, count interface{}) { + e.page = parseNum(page, 0) + e.limit = parseNum(count, 10) +} + +func parseNum(value interface{}, defaultVal int) int { + if n, ok := value.(float64); ok { + return int(n) + } + if n, ok := value.(int); ok { + return n + } + return defaultVal +} diff --git a/query_node_test.go b/query_node_test.go new file mode 100644 index 0000000..d94d771 --- /dev/null +++ b/query_node_test.go @@ -0,0 +1,29 @@ +package vql + +import ( + "encoding/json" + "log" + "testing" + + "git.ouxuan.net/3136352472/vql/apijson" +) + +func TestNewQueryNode(t *testing.T) { + apijson.SetQueryAll(func(sql string, args ...interface{}) ([]map[string]interface{}, error) { + log.Println(sql, args) + sqlx, err := Analyze(sql) + if err != nil { + return nil, err + } + log.Println(sqlx.ToSql()) + return nil, nil + }) + + reqStr := `{ "Moment": { "id":12 } }` + req := make(map[string]interface{}) + json.Unmarshal([]byte(reqStr), &req) + + ctx := apijson.NewQueryContext(req) + + log.Println(ctx.Response()) +} diff --git a/vql.go b/vql.go index e3ae627..9ec0b37 100644 --- a/vql.go +++ b/vql.go @@ -169,3 +169,7 @@ func Analyze(sql string) (*Query, error) { err := q.FromSql(sql) return &q, err } + +// func (vql *VirtualQL) CompileWithApijson(sql apijson.QueryNode, param map[string]interface{}) (*Query, error) { + +// }