掘金 后端 ( ) • 2024-06-28 15:25

theme: channing-cyan

前言

最近碰上有个业务,查询的sql如下:

select * from table where (sku_id,batch_no) in ((#{skuId},#{batchNo}),...);  

本来也没什么,很简单常见的一种sql。
问题是我们使用的是mybatis-plus,然后写的时候有没有考虑到后面的查询条件,这里用的是mybatis-plus lambda的方式。
示例如下:

LambdaQueryChainWrapper<Table> query = tableService.lambdaQuery();  
query.eq(Table::getId, param.getId());  

但是mysql-plus并没有支持这种sql的形式,要么用apply方法自定义拼接sql,要么不采用lambda方式,将语句写成xml形式。
不过,第一种方式又感觉很low,一大段java代码里插入一段sql字符串,看上去就很别扭,因为有点代码洁癖,只能果断放弃。
第二种改动又太大,前面那么多查询条件,又要全部移入xml里面,同时,本来已经通过测试的筛选,又要重新来一遍,太懒,实在干不动。
于是,又想想这么常见的场景,网上理应有解决方案。但不知道是搜索关键字不对还是确实没有,搜了半天没搜出来。
无奈,只有尝试自己扩展一下。

扩展

扩展第三方包功能,这种事老司机应该已经轻车熟路了,可以划走了。我分享下自己的思路,这里以mybatis-plus lambda为例,其他的任何第三方包其实都一样。
首先应该明白一点,一个较为成熟的第三方包,里面的代码逻辑不亚于我们任何一个系统的代码逻辑。
所以如果平时我们开发中,想要去理清五六成的代码逻辑再动手,那领导自然不是很高兴(费时费力)。
其实,最好最快的方式,就是看看我们需要使用的类或方法,在第三方包里面的他们的实现思路是什么。
这里,我想要的结果,就是LambdaQueryChainWrapper这个类下有一个方法支持(sku_id,batch_no) in ((#{skuId},#{batchNo}),...)这种条件查询。
那么,我们找一个与这个方法相近的in方法,看看这里面的实现是怎么样的。 点进来:

image.png 进的是mybtisplus包下面func的一个接口。
先看看in的实现类有哪些

image.png 看上去LambdaQueryChainWrapper并没有直接实现这个接口。别急,这种一看就是这两个类肯定是LambdaQueryChainWrapper的父类。我们不急着点进去,先记下这两个类。
为什么不急着点进去看看这两个类怎么实现的in方法呢?
因为大部分情况,这里点出来的实现类布置一两个,大部情况可能十几二十个。从上层一个找,那么可能几天都看不完,而现在要做的,是从下网上找。
我们找到LambdaQueryChainWrapper这个类

image.png 它在mybatisplus的扩展包里。这里很简单,LambdaQueryChainWrapper继承自AbstractChainWrapper,运气很好,是我们刚才看见的in实现类的其中一个类。
为什么说运气很好?
因为这里还会存在的情况是,LambdaQueryChainWrapper并不直接继承in方法的实现类,而是它的父类,甚至父类的父类,四五层的父类才实现的in的方法。这种情况就需要一层层往上找,找到上面实现in方法的具体继承下来的类。
好了,看一下AbstractChainWrapper类中in方法的实现

image.png 还是比较简单,就两行。我们关注下这行

getWrapper().in(condition, column, coll);

getWrapper()方法返回了一个AbstractWrapper对象。
可以先不用关注AbstractWrapper干了什么,我们感兴趣的是它执行了in方法。

image.png AbstractWrapper类中实现的in方法具体实现在doIt方法中

image.png 看上去也没什么需要关注的点

expression.add(sqlSegments);

看看add方法里面在干什么

image.png 嗯?已经在处理具体的sql片段了。从这里能看出来,大概意思orderBy、groupBy、having会加入特定集合,其他的会加入normal集合。
看到这就应该刹住车了,为什么?
因为既然这里没有看到将lambda处理成sql字符串的逻辑,而且很明显是准备做字符串拼接前的准备了。那么我们想要的将lambda方法处理成sql字符串的操作就是在前面执行的。
我们往前翻一翻 image.png 这里我们还漏了点东西。三个参数,condition不用去关注,这里一想就能明白,就是是否执行语句的标识。第二个参数是一个执行方法,看名字就能看出,是将字段转换成字符串的,点进去看看。

image.png 没什么复杂逻辑,再看看inExpression方法

image.png 嗯?看上去是不是有那么点意思了

image.png

image.png formatSql的具体实现,就是把{0}替换成集合中具体的值。然后在inExpression方法返回时拼接成(值1,值2...)的形式。
到这里我们就能看明白,inExpression方法其实就是拼接的id in (id1,id2...)的条件。 那么,我们再来看看in的方法 image.png 结构是不是就是 字段 in (查询条件),和我们实际的sql基本接近。
那么到这里其实就已经有了大概的实现思路了。

实现思路

思路其实也比较简单,我们可以按照AbstractWrapper类中in方法的实现,实现我们想要的效果。
首先,新增一个类,参考LambdaQueryChainWrapper类,继承自AbstractChainWrapper。
为什么没有继承LambdaQueryChainWrapper?

image.png 因为字段的泛型在AbstractWrapper上,我们也需要这个泛型,用于传参。
但为了保证select功能可用,我们需要把LambdaQueryChainWrapper中的方法全部copy一份到我们新的类上。然后我们来构建我们需要的方法:

image.png copy一份AbstractWrapper中in的实现,改巴改巴。我们需要传两个字段,同时需要两个集合参数。
为什么这里不用不定参数?
因为不定参只能在方法参数最后,这里用不定参就是字段和数据集合两个不定参,不好区分。
好了,现在的问题是,这里的几个方法在本类并没有被执行,就像图上所示红色的方法。
从原AbstractWrapper类中的实现我们可以看到,修饰这些方法的private和prodtected,但我们的实现类是在我们自己的包里的。所以并不能访问到原包下的方法。
每一个方法都copy一份到我们自己的包下?
当然这是不可取的,方法一个套一个,如果稍不注意就会遗漏或被修改,都容易出问题。
这里,比较常见的做法就是利用反射机制。
比如:

image.png 利用反射直接调用原实现类的方法,有一个小点是大部分方法是受保护的,所以需要执行ReflectionUtils.makeAccessible(method)
然后,我们来看下需要具体改造的点。
首先,字段需要拼接成(字段名1, 字段名2)的形式,这部分很简单,参考AbstractWrapper实现类,我们这样拼接

StringPool.LEFT_BRACKET + columnsToString(true, column1, column2) + StringPool.RIGHT_BRACKET

另一个点就是拼接后面的参数,需要改成((参数值1,参数值2),(参数值1,参数值2)...)这种方式
虽然会费点功夫,但也基本好搞,这里直接上完整代码

public class CustomerQueryChainWrapper<T> extends AbstractChainWrapper<T, SFunction<T, ?>, CustomerQueryChainWrapper<T>, LambdaQueryWrapper<T>>  
implements ChainQuery<T>, Query<CustomerQueryChainWrapper<T>, T, SFunction<T, ?>> {  
  
    private final BaseMapper<T> baseMapper;  
  
    public CustomerQueryChainWrapper(BaseMapper<T> baseMapper) {  
        super();  
        this.baseMapper = baseMapper;  
        super.wrapperChildren = new LambdaQueryWrapper<>();  
    }  
  
    @SafeVarargs  
    @Override  
    public final CustomerQueryChainWrapper<T> select(SFunction<T, ?>... columns) {  
        wrapperChildren.select(columns);  
        return typedThis;  
    }  
  
    @Override  
    public CustomerQueryChainWrapper<T> select(Class<T> entityClass, Predicate<TableFieldInfo> predicate) {  
        wrapperChildren.select(entityClass, predicate);  
        return typedThis;  
    }  
  
    @Override  
    public String getSqlSelect() {  
        throw ExceptionUtils.mpe("can not use this method for \"%s\"", "getSqlSelect");  
    }  
  
    @Override  
    public BaseMapper<T> getBaseMapper() {  
        return baseMapper;  
    }  
  
    public LambdaQueryWrapper<T> inOfParam2(SFunction<T, ?> column1, SFunction<T, ?> column2,  
                                    List<?> coll1, List<?> coll2) {  
        return doIt(true,  
            () -> StringPool.LEFT_BRACKET + columnsToString(true, column1, column2) + StringPool.RIGHT_BRACKET,  
            IN,  
            inExpressionOfParam(coll1, coll2));  
    }  

    public LambdaQueryWrapper<T> inOfParam3(SFunction<T, ?> column1, SFunction<T, ?> column2,  SFunction<T, ?> column3,  
                    List<?> coll1, List<?> coll2, List<?> coll3) {  
        return doIt(true,  
            () -> StringPool.LEFT_BRACKET + columnsToString(true, column1, column2, column3) + StringPool.RIGHT_BRACKET,  
            IN,  
            inExpressionOfParam(coll1, coll2, coll3));  
    }  
  
    /**  
    * 获取 columnName  
    */  
    protected String columnsToString(boolean onlyColumn, SFunction<T, ?>... columns) {  
        return Arrays.stream(columns).map(i -> columnToString(i, onlyColumn)).collect(joining(StringPool.COMMA));  
    }  
  
    protected String columnToString(SFunction<T, ?> column, boolean onlyColumn) {  
        Method method = ReflectUtil.getMethod(AbstractLambdaWrapper.class, "getColumn", SerializedLambda.class, Boolean.class);  
        try {  
            ReflectionUtils.makeAccessible(method);  
            return (String) method.invoke(wrapperChildren, LambdaUtils.resolve(column), onlyColumn);  
        } catch (IllegalAccessException e) {  
            throw new RuntimeException(e);  
        } catch (InvocationTargetException e) {  
            throw new RuntimeException(e);  
        }  
    }  
  
    private ISqlSegment inExpressionOfParam(List<?>... coll) {  
        int collSize = coll.length;  
        if (collSize == 0) {  
            return () -> "";  
        }  
        int size1 = coll[0].size();  
        if (size1 > 1000) {  
            size1 = 1000;  
        }  
        List<String> resultList = new ArrayList<>();  
        List<String> formatList = new ArrayList<>();  
        for (int i = 0; i<collSize; i++) {  
            formatList.add("{" + i + "}");  
        }  
        String format = formatList.stream().collect(joining(StringPool.COMMA, StringPool.LEFT_BRACKET, StringPool.RIGHT_BRACKET));  
        for (int i=0; i < size1; i++) {  
            Object[] objectArr = new Object[collSize];  
            for (int j = 0; j < collSize; j++) {  
                if (coll[j].size() >= i) {  
                    objectArr[j] = coll[j].get(i);  
                }  
            }  
            resultList.add(formatSql(format , objectArr));  
        }  
        return () -> resultList.stream()  
        .collect(joining(StringPool.COMMA, StringPool.LEFT_BRACKET, StringPool.RIGHT_BRACKET));  
    }  
  
    protected final String formatSql(String sqlStr, Object... params) {  
        Method method = ReflectUtil.getMethod(AbstractWrapper.class, "formatSql");  
        try {  
            ReflectionUtils.makeAccessible(method);  
            return (String) method.invoke(wrapperChildren, sqlStr, params);  
        } catch (IllegalAccessException e) {  
            throw new RuntimeException(e);  
        } catch (InvocationTargetException e) {  
            throw new RuntimeException(e);  
        }  
    }  
  
    protected LambdaQueryWrapper<T> doIt(boolean condition, ISqlSegment... sqlSegments) {  
        Method method = ReflectUtil.getMethod(AbstractWrapper.class, "doIt");;  
        try {  
            ReflectionUtils.makeAccessible(method);  
            return (LambdaQueryWrapper<T>) method.invoke(wrapperChildren, true, sqlSegments);  
        } catch (IllegalAccessException e) {  
            throw new RuntimeException(e);  
        } catch (InvocationTargetException e) {  
            throw new RuntimeException(e);  
        }  
    }  
}

拼接参数的主要逻辑在inExpressionOfParam里面
最后看一下使用

image.png 在具体需要用到的service里增加该方法,返回我们自己的实现类

image.png 直接使用即可

image.png 看一下sql执行日志,完美!

总结

今天接开发需求的一个例子,主要是像和大家分享平时遇上第三方包不支持的业务场景,如何快速的在其基础上扩展。
总结一下,首先明确目标,我们需要达到什么样的效果,其次,找到现有的实现类,模仿其中的写法。最后,利用反射机制,调用被保护的方法,达到我们的目的。