问题描述

在golang中使用mysql预处理查询时,一个"?"代表一个占位符,但是当查询条件中使用到in查询时,in的传值个数有时不是固定的,例如假设一个订单的状态如以下枚举:

type OrderStatus int32

const (
	All    OrderStatus = 0 //所有
	UnPaid OrderStatus = 1 //待付款
	Cancel OrderStatus = 2 //已取消
	Paid   OrderStatus = 3 //已付款
	Closed OrderStatus = 4 //已关闭
)

现在需要查询指定状态的订单记录,状态是由用户选择的,可能是为所有订单,也可能是只包含待付款、已取消、已关闭的订单,那么我们的sql该怎样构造呢?:
假设订单表的定义如下:

CREATE TABLE `tb_orders` (
  `id` int(11) NOT NULL AUTO_INCREMENT COMMENT '主键ID',
  `order_sn` varchar(16) NOT NULL COMMENT '订单编号',
  `order_status` tinyint(4) NOT NULL COMMENT '订单状态(1-待付款,2-已取消,3-已付款,4-已关闭)',
  `create_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4

插入如下数据:

INSERT INTO `tb_orders` VALUES ('1', '201911120936', '1', '2019-11-11 21:36:30'), ('2', '201911100824', '2', '2019-11-11 21:36:53'), ('3', '201911090726', '1', '2019-11-11 21:41:06'), ('4', '201911100753', '3', '2019-11-11 21:41:55'), ('5', '201911230587', '4', '2019-11-11 21:42:10');

在这里插入图片描述

查询SQL可能这样:

select * from tb_orders where order_status in (?)

程序代码如下:

func main() {
	db, err := sql.Open("mysql", "...")
	if err != nil {
		panic(err)
	}
	sql := "select order_sn from tb_orders where order_status in(?)"
	stmt, err := db.Prepare(sql)
	if err != nil {
		panic(err)
	}
	defer stmt.Close()
	rows, err := stmt.Query("1,2,4") //查询待付款,已取消,已关闭
	defer rows.Close()
	for rows.Next() {
		var order_sn string
		if ers := rows.Scan(&order_sn); ers == nil {
			fmt.Println(order_sn)
		}
	}
}

当我们运行程序输出如下:

201911120936
201911090726

显然只查询到了状态为1的订单,但是我们程序中传入的参数是“1,2,4",分析后发现原因是:
预处理中?只是占位符,在处理时,将传入参数转换为字符串,因此数据库带入参数最终执行的SQL如下:

select * from tb_orders where order_status in ('1,2,4')

而不是想要的:

select * from tb_orders where order_status in (1,2,4)

正确的方法只有需要查询几个状态,则增加几个占位符。如:

select * from tb_orders where order_status in (?,?,?)

但是为了程序的灵活性,有没有办法可以根据传入的状态数量动态增加占位符呢?

解决方案

我们可以实现一个帮助类,根据传入的参数修改定义的SQL。
直接上代码:

package main

import (
	"fmt"
	"regexp"
	"strings"
)

type paramSqlPrepare struct {
	i    int
	sql  string
	args []string
	lens []int
}

func (p *paramSqlPrepare) replace(str string) string {
	var dstr []string
	l := 0
	for l < p.lens[p.i] {
		dstr = append(dstr, "?")
		l++
	}

	p.i++

	return strings.Join(dstr, ",")

}

func (p *paramSqlPrepare) prepare() (sql string, args []string) {

	for _, arg := range p.args {

		v_arr := strings.Split(arg, ",")
		l := len(v_arr)
		p.lens = append(p.lens, l)
		if l > 1 {
			for _, v := range v_arr {
				args = append(args, v)

			}
		} else {
			args = append(args, arg)
		}

	}

	rep, _ := regexp.Compile("\\?")
	sql = rep.ReplaceAllStringFunc(p.sql, p.replace)

	return sql, args

}

func main() {

	data := "select order_sn from tb_orders where order_status in(?)"
	psp := paramSqlPrepare{
		i:    0,
		sql:  data,
		args: []string{"1,2,3"},
	}
	sql, args := psp.prepare()
	fmt.Println(sql)
	fmt.Println(args)
}

传入参数:

select order_sn from tb_orders where order_status in(?)
"1,2,3"

输出:

select order_sn from tb_orders where order_status in(?,?,?)
[1 2 3]

实现原理:将传入的参数转换为数组,并计算出元素个数,利用正则匹配和函数替换将占位符替换为指定的元素个数,同时原参数全部转为一维数组。

github源码:https://github.com/suyaoli/example-001

Logo

欢迎加入西安开发者社区!我们致力于为西安地区的开发者提供学习、合作和成长的机会。参与我们的活动,与专家分享最新技术趋势,解决挑战,探索创新。加入我们,共同打造技术社区!

更多推荐