初始化提交

This commit is contained in:
YangJ
2024-03-20 09:42:17 +08:00
commit 72f30209cf
3705 changed files with 285827 additions and 0 deletions

47
yudao-framework/pom.xml Normal file
View File

@ -0,0 +1,47 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<artifactId>yudao</artifactId>
<groupId>cn.iocoder.boot</groupId>
<version>${revision}</version>
</parent>
<packaging>pom</packaging>
<modules>
<module>yudao-common</module>
<module>yudao-spring-boot-starter-mybatis</module>
<module>yudao-spring-boot-starter-redis</module>
<module>yudao-spring-boot-starter-web</module>
<module>yudao-spring-boot-starter-security</module>
<module>yudao-spring-boot-starter-websocket</module>
<module>yudao-spring-boot-starter-monitor</module>
<module>yudao-spring-boot-starter-protection</module>
<module>yudao-spring-boot-starter-job</module>
<module>yudao-spring-boot-starter-mq</module>
<module>yudao-spring-boot-starter-excel</module>
<module>yudao-spring-boot-starter-test</module>
<module>yudao-spring-boot-starter-biz-operatelog</module>
<module>yudao-spring-boot-starter-biz-tenant</module>
<module>yudao-spring-boot-starter-biz-data-permission</module>
<module>yudao-spring-boot-starter-biz-ip</module>
</modules>
<artifactId>yudao-framework</artifactId>
<description>
该包是技术组件,每个子包,代表一个组件。每个组件包括两部分:
1. core 包:是该组件的核心封装
2. config 包:是该组件基于 Spring 的配置
技术组件,也分成两类:
1. 框架组件:和我们熟悉的 MyBatis、Redis 等等的拓展
2. 业务组件:和业务相关的组件的封装,例如说数据字典、操作日志等等。
如果是业务组件Maven 名字会包含 biz
</description>
<url>https://github.com/YunaiV/ruoyi-vue-pro</url>
</project>

View File

@ -0,0 +1,144 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-framework</artifactId>
<version>${revision}</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>yudao-common</artifactId>
<packaging>jar</packaging>
<name>${project.artifactId}</name>
<description>定义基础 pojo 类、枚举、工具类等等</description>
<url>https://github.com/YunaiV/ruoyi-vue-pro</url>
<dependencies>
<!-- Spring 核心 -->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-core</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-expression</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-aop</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<dependency>
<groupId>org.aspectj</groupId>
<artifactId>aspectjweaver</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<dependency>
<!-- 用于生成自定义的 Spring @ConfigurationProperties 配置类的说明文件 -->
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
<optional>true</optional>
</dependency>
<!-- Web 相关 -->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<dependency>
<groupId>jakarta.servlet</groupId>
<artifactId>jakarta.servlet-api</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<dependency>
<groupId>org.springdoc</groupId>
<artifactId>springdoc-openapi-ui</artifactId>
<scope>provided</scope> <!-- 设置为 provided主要是 PageParam 使用到 -->
</dependency>
<!-- 监控相关 -->
<dependency>
<groupId>org.apache.skywalking</groupId>
<artifactId>apm-toolkit-trace</artifactId>
</dependency>
<!-- 工具类相关 -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<dependency>
<groupId>org.mapstruct</groupId>
<artifactId>mapstruct</artifactId>
</dependency>
<dependency>
<groupId>org.mapstruct</groupId>
<artifactId>mapstruct-jdk8</artifactId> <!-- use mapstruct-jdk8 for Java 8 or higher -->
</dependency>
<dependency>
<groupId>org.mapstruct</groupId>
<artifactId>mapstruct-processor</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<dependency>
<groupId>jakarta.validation</groupId>
<artifactId>jakarta.validation-api</artifactId>
<scope>provided</scope> <!-- 设置为 provided主要是 PageParam 使用到 -->
</dependency>
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>transmittable-thread-local</artifactId>
</dependency>
<!-- Test 测试相关 -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,15 @@
package cn.iocoder.yudao.framework.common.core;
/**
* 可生成 Int 数组的接口
*
* @author 芋道源码
*/
public interface IntArrayValuable {
/**
* @return int 数组
*/
int[] array();
}

View File

@ -0,0 +1,22 @@
package cn.iocoder.yudao.framework.common.core;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
/**
* Key Value 的键值对
*
* @author 芋道源码
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
public class KeyValue<K, V> implements Serializable {
private K key;
private V value;
}

View File

@ -0,0 +1,46 @@
package cn.iocoder.yudao.framework.common.enums;
import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* 通用状态枚举
*
* @author 芋道源码
*/
@Getter
@AllArgsConstructor
public enum CommonStatusEnum implements IntArrayValuable {
ENABLE(0, "开启"),
DISABLE(1, "关闭");
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(CommonStatusEnum::getStatus).toArray();
/**
* 状态值
*/
private final Integer status;
/**
* 状态名
*/
private final String name;
@Override
public int[] array() {
return ARRAYS;
}
public static boolean isEnable(Integer status) {
return ObjUtil.equal(ENABLE.status, status);
}
public static boolean isDisable(Integer status) {
return ObjUtil.equal(DISABLE.status, status);
}
}

View File

@ -0,0 +1,21 @@
package cn.iocoder.yudao.framework.common.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 文档地址
*
* @author 芋道源码
*/
@Getter
@AllArgsConstructor
public enum DocumentEnum {
REDIS_INSTALL("https://gitee.com/zhijiantianya/ruoyi-vue-pro/issues/I4VCSJ", "Redis 安装文档"),
TENANT("https://doc.iocoder.cn", "SaaS 多租户文档");
private final String url;
private final String memo;
}

View File

@ -0,0 +1,40 @@
package cn.iocoder.yudao.framework.common.enums;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import java.util.Arrays;
/**
* 终端的枚举
*
* @author 芋道源码
*/
@RequiredArgsConstructor
@Getter
public enum TerminalEnum implements IntArrayValuable {
UNKNOWN(0, "未知"), // 目的:在无法解析到 terminal 时,使用它
WECHAT_MINI_PROGRAM(10, "微信小程序"),
WECHAT_WAP(11, "微信公众号"),
H5(20, "H5 网页"),
APP(31, "手机 App"),
;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(TerminalEnum::getTerminal).toArray();
/**
* 终端
*/
private final Integer terminal;
/**
* 终端名
*/
private final String name;
@Override
public int[] array() {
return ARRAYS;
}
}

View File

@ -0,0 +1,39 @@
package cn.iocoder.yudao.framework.common.enums;
import cn.hutool.core.util.ArrayUtil;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* 全局用户类型枚举
*/
@AllArgsConstructor
@Getter
public enum UserTypeEnum implements IntArrayValuable {
MEMBER(1, "会员"), // 面向 c 端,普通用户
ADMIN(2, "管理员"); // 面向 b 端,管理后台
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(UserTypeEnum::getValue).toArray();
/**
* 类型
*/
private final Integer value;
/**
* 类型名
*/
private final String name;
public static UserTypeEnum valueOf(Integer value) {
return ArrayUtil.firstMatch(userType -> userType.getValue().equals(value), UserTypeEnum.values());
}
@Override
public int[] array() {
return ARRAYS;
}
}

View File

@ -0,0 +1,34 @@
package cn.iocoder.yudao.framework.common.enums;
/**
* Web 过滤器顺序的枚举类,保证过滤器按照符合我们的预期
*
* 考虑到每个 starter 都需要用到该工具类,所以放到 common 模块下的 enums 包下
*
* @author 芋道源码
*/
public interface WebFilterOrderEnum {
int CORS_FILTER = Integer.MIN_VALUE;
int TRACE_FILTER = CORS_FILTER + 1;
int REQUEST_BODY_CACHE_FILTER = Integer.MIN_VALUE + 500;
// OrderedRequestContextFilter 默认为 -105用于国际化上下文等等
int TENANT_CONTEXT_FILTER = - 104; // 需要保证在 ApiAccessLogFilter 前面
int API_ACCESS_LOG_FILTER = -103; // 需要保证在 RequestBodyCacheFilter 后面
int XSS_FILTER = -102; // 需要保证在 RequestBodyCacheFilter 后面
// Spring Security Filter 默认为 -100可见 org.springframework.boot.autoconfigure.security.SecurityProperties 配置属性类
int TENANT_SECURITY_FILTER = -99; // 需要保证在 Spring Security 过滤器后面
int FLOWABLE_FILTER = -98; // 需要保证在 Spring Security 过滤后面
int DEMO_FILTER = Integer.MAX_VALUE;
}

View File

@ -0,0 +1,32 @@
package cn.iocoder.yudao.framework.common.exception;
import cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants;
import cn.iocoder.yudao.framework.common.exception.enums.ServiceErrorCodeRange;
import lombok.Data;
/**
* 错误码对象
*
* 全局错误码,占用 [0, 999], 参见 {@link GlobalErrorCodeConstants}
* 业务异常错误码,占用 [1 000 000 000, +∞),参见 {@link ServiceErrorCodeRange}
*
* TODO 错误码设计成对象的原因,为未来的 i18 国际化做准备
*/
@Data
public class ErrorCode {
/**
* 错误码
*/
private final Integer code;
/**
* 错误提示
*/
private final String msg;
public ErrorCode(Integer code, String message) {
this.code = code;
this.msg = message;
}
}

View File

@ -0,0 +1,60 @@
package cn.iocoder.yudao.framework.common.exception;
import cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants;
import lombok.Data;
import lombok.EqualsAndHashCode;
/**
* 服务器异常 Exception
*/
@Data
@EqualsAndHashCode(callSuper = true)
public final class ServerException extends RuntimeException {
/**
* 全局错误码
*
* @see GlobalErrorCodeConstants
*/
private Integer code;
/**
* 错误提示
*/
private String message;
/**
* 空构造方法,避免反序列化问题
*/
public ServerException() {
}
public ServerException(ErrorCode errorCode) {
this.code = errorCode.getCode();
this.message = errorCode.getMsg();
}
public ServerException(Integer code, String message) {
this.code = code;
this.message = message;
}
public Integer getCode() {
return code;
}
public ServerException setCode(Integer code) {
this.code = code;
return this;
}
@Override
public String getMessage() {
return message;
}
public ServerException setMessage(String message) {
this.message = message;
return this;
}
}

View File

@ -0,0 +1,60 @@
package cn.iocoder.yudao.framework.common.exception;
import cn.iocoder.yudao.framework.common.exception.enums.ServiceErrorCodeRange;
import lombok.Data;
import lombok.EqualsAndHashCode;
/**
* 业务逻辑异常 Exception
*/
@Data
@EqualsAndHashCode(callSuper = true)
public final class ServiceException extends RuntimeException {
/**
* 业务错误码
*
* @see ServiceErrorCodeRange
*/
private Integer code;
/**
* 错误提示
*/
private String message;
/**
* 空构造方法,避免反序列化问题
*/
public ServiceException() {
}
public ServiceException(ErrorCode errorCode) {
this.code = errorCode.getCode();
this.message = errorCode.getMsg();
}
public ServiceException(Integer code, String message) {
this.code = code;
this.message = message;
}
public Integer getCode() {
return code;
}
public ServiceException setCode(Integer code) {
this.code = code;
return this;
}
@Override
public String getMessage() {
return message;
}
public ServiceException setMessage(String message) {
this.message = message;
return this;
}
}

View File

@ -0,0 +1,41 @@
package cn.iocoder.yudao.framework.common.exception.enums;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
/**
* 全局错误码枚举
* 0-999 系统异常编码保留
*
* 一般情况下,使用 HTTP 响应状态码 https://developer.mozilla.org/zh-CN/docs/Web/HTTP/Status
* 虽然说HTTP 响应状态码作为业务使用表达能力偏弱,但是使用在系统层面还是非常不错的
* 比较特殊的是,因为之前一直使用 0 作为成功,就不使用 200 啦。
*
* @author 芋道源码
*/
public interface GlobalErrorCodeConstants {
ErrorCode SUCCESS = new ErrorCode(0, "成功");
// ========== 客户端错误段 ==========
ErrorCode BAD_REQUEST = new ErrorCode(400, "请求参数不正确");
ErrorCode UNAUTHORIZED = new ErrorCode(401, "账号未登录");
ErrorCode FORBIDDEN = new ErrorCode(403, "没有该操作权限");
ErrorCode NOT_FOUND = new ErrorCode(404, "请求未找到");
ErrorCode METHOD_NOT_ALLOWED = new ErrorCode(405, "请求方法不正确");
ErrorCode LOCKED = new ErrorCode(423, "请求失败,请稍后重试"); // 并发请求,不允许
ErrorCode TOO_MANY_REQUESTS = new ErrorCode(429, "请求过于频繁,请稍后重试");
// ========== 服务端错误段 ==========
ErrorCode INTERNAL_SERVER_ERROR = new ErrorCode(500, "系统异常");
ErrorCode NOT_IMPLEMENTED = new ErrorCode(501, "功能未实现/未开启");
ErrorCode ERROR_CONFIGURATION = new ErrorCode(502, "错误的配置项");
// ========== 自定义错误段 ==========
ErrorCode REPEATED_REQUESTS = new ErrorCode(900, "重复请求,请稍后重试"); // 重复请求
ErrorCode DEMO_DENY = new ErrorCode(901, "演示模式,禁止写操作");
ErrorCode UNKNOWN = new ErrorCode(999, "未知错误");
}

View File

@ -0,0 +1,46 @@
package cn.iocoder.yudao.framework.common.exception.enums;
/**
* 业务异常的错误码区间,解决:解决各模块错误码定义,避免重复,在此只声明不做实际使用
*
* 一共 10 位,分成四段
*
* 第一段1 位,类型
* 1 - 业务级别异常
* x - 预留
* 第二段3 位,系统类型
* 001 - 用户系统
* 002 - 商品系统
* 003 - 订单系统
* 004 - 支付系统
* 005 - 优惠劵系统
* ... - ...
* 第三段3 位,模块
* 不限制规则。
* 一般建议,每个系统里面,可能有多个模块,可以再去做分段。以用户系统为例子:
* 001 - OAuth2 模块
* 002 - User 模块
* 003 - MobileCode 模块
* 第四段3 位,错误码
* 不限制规则。
* 一般建议,每个模块自增。
*
* @author 芋道源码
*/
public class ServiceErrorCodeRange {
// 模块 infra 错误码区间 [1-001-000-000 ~ 1-002-000-000)
// 模块 system 错误码区间 [1-002-000-000 ~ 1-003-000-000)
// 模块 report 错误码区间 [1-003-000-000 ~ 1-004-000-000)
// 模块 member 错误码区间 [1-004-000-000 ~ 1-005-000-000)
// 模块 mp 错误码区间 [1-006-000-000 ~ 1-007-000-000)
// 模块 pay 错误码区间 [1-007-000-000 ~ 1-008-000-000)
// 模块 bpm 错误码区间 [1-009-000-000 ~ 1-010-000-000)
// 模块 product 错误码区间 [1-008-000-000 ~ 1-009-000-000)
// 模块 trade 错误码区间 [1-011-000-000 ~ 1-012-000-000)
// 模块 promotion 错误码区间 [1-013-000-000 ~ 1-014-000-000)
// 模块 crm 错误码区间 [1-020-000-000 ~ 1-021-000-000)
}

View File

@ -0,0 +1,127 @@
package cn.iocoder.yudao.framework.common.exception.util;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.common.exception.ServiceException;
import cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants;
import com.google.common.annotations.VisibleForTesting;
import lombok.extern.slf4j.Slf4j;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
/**
* {@link ServiceException} 工具类
*
* 目的在于,格式化异常信息提示。
* 考虑到 String.format 在参数不正确时会报错,因此使用 {} 作为占位符,并使用 {@link #doFormat(int, String, Object...)} 方法来格式化
*
* 因为 {@link #MESSAGES} 里面默认是没有异常信息提示的模板的,所以需要使用方自己初始化进去。目前想到的有几种方式:
*
* 1. 异常提示信息写在枚举类中例如说cn.iocoder.oceans.user.api.constants.ErrorCodeEnum 类 + ServiceExceptionConfiguration
* 2. 异常提示信息,写在 .properties 等等配置文件
* 3. 异常提示信息,写在 Apollo 等等配置中心中,从而实现可动态刷新
* 4. 异常提示信息,存储在 db 等等数据库中,从而实现可动态刷新
*/
@Slf4j
public class ServiceExceptionUtil {
/**
* 错误码提示模板
*/
private static final ConcurrentMap<Integer, String> MESSAGES = new ConcurrentHashMap<>();
public static void putAll(Map<Integer, String> messages) {
ServiceExceptionUtil.MESSAGES.putAll(messages);
}
public static void put(Integer code, String message) {
ServiceExceptionUtil.MESSAGES.put(code, message);
}
public static void delete(Integer code, String message) {
ServiceExceptionUtil.MESSAGES.remove(code, message);
}
// ========== 和 ServiceException 的集成 ==========
public static ServiceException exception(ErrorCode errorCode) {
String messagePattern = MESSAGES.getOrDefault(errorCode.getCode(), errorCode.getMsg());
return exception0(errorCode.getCode(), messagePattern);
}
public static ServiceException exception(ErrorCode errorCode, Object... params) {
String messagePattern = MESSAGES.getOrDefault(errorCode.getCode(), errorCode.getMsg());
return exception0(errorCode.getCode(), messagePattern, params);
}
/**
* 创建指定编号的 ServiceException 的异常
*
* @param code 编号
* @return 异常
*/
public static ServiceException exception(Integer code) {
return exception0(code, MESSAGES.get(code));
}
/**
* 创建指定编号的 ServiceException 的异常
*
* @param code 编号
* @param params 消息提示的占位符对应的参数
* @return 异常
*/
public static ServiceException exception(Integer code, Object... params) {
return exception0(code, MESSAGES.get(code), params);
}
public static ServiceException exception0(Integer code, String messagePattern, Object... params) {
String message = doFormat(code, messagePattern, params);
return new ServiceException(code, message);
}
public static ServiceException invalidParamException(String messagePattern, Object... params) {
return exception0(GlobalErrorCodeConstants.BAD_REQUEST.getCode(), messagePattern, params);
}
// ========== 格式化方法 ==========
/**
* 将错误编号对应的消息使用 params 进行格式化。
*
* @param code 错误编号
* @param messagePattern 消息模版
* @param params 参数
* @return 格式化后的提示
*/
@VisibleForTesting
public static String doFormat(int code, String messagePattern, Object... params) {
StringBuilder sbuf = new StringBuilder(messagePattern.length() + 50);
int i = 0;
int j;
int l;
for (l = 0; l < params.length; l++) {
j = messagePattern.indexOf("{}", i);
if (j == -1) {
log.error("[doFormat][参数过多:错误码({})|错误内容({})|参数({})", code, messagePattern, params);
if (i == 0) {
return messagePattern;
} else {
sbuf.append(messagePattern.substring(i));
return sbuf.toString();
}
} else {
sbuf.append(messagePattern, i, j);
sbuf.append(params[l]);
i = j + 2;
}
}
if (messagePattern.indexOf("{}", i) != -1) {
log.error("[doFormat][参数过少:错误码({})|错误内容({})|参数({})", code, messagePattern, params);
}
sbuf.append(messagePattern.substring(i));
return sbuf.toString();
}
}

View File

@ -0,0 +1,6 @@
/**
* 基础的通用类,和框架无关
*
* 例如说CommonResult 为通用返回
*/
package cn.iocoder.yudao.framework.common;

View File

@ -0,0 +1,112 @@
package cn.iocoder.yudao.framework.common.pojo;
import cn.iocoder.yudao.framework.common.exception.ErrorCode;
import cn.iocoder.yudao.framework.common.exception.ServiceException;
import cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants;
import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.Data;
import org.springframework.util.Assert;
import java.io.Serializable;
import java.util.Objects;
/**
* 通用返回
*
* @param <T> 数据泛型
*/
@Data
public class CommonResult<T> implements Serializable {
/**
* 错误码
*
* @see ErrorCode#getCode()
*/
private Integer code;
/**
* 返回数据
*/
private T data;
/**
* 错误提示,用户可阅读
*
* @see ErrorCode#getMsg() ()
*/
private String msg;
/**
* 将传入的 result 对象,转换成另外一个泛型结果的对象
*
* 因为 A 方法返回的 CommonResult 对象,不满足调用其的 B 方法的返回,所以需要进行转换。
*
* @param result 传入的 result 对象
* @param <T> 返回的泛型
* @return 新的 CommonResult 对象
*/
public static <T> CommonResult<T> error(CommonResult<?> result) {
return error(result.getCode(), result.getMsg());
}
public static <T> CommonResult<T> error(Integer code, String message) {
Assert.isTrue(!GlobalErrorCodeConstants.SUCCESS.getCode().equals(code), "code 必须是错误的!");
CommonResult<T> result = new CommonResult<>();
result.code = code;
result.msg = message;
return result;
}
public static <T> CommonResult<T> error(ErrorCode errorCode) {
return error(errorCode.getCode(), errorCode.getMsg());
}
public static <T> CommonResult<T> success(T data) {
CommonResult<T> result = new CommonResult<>();
result.code = GlobalErrorCodeConstants.SUCCESS.getCode();
result.data = data;
result.msg = "";
return result;
}
public static boolean isSuccess(Integer code) {
return Objects.equals(code, GlobalErrorCodeConstants.SUCCESS.getCode());
}
@JsonIgnore // 避免 jackson 序列化
public boolean isSuccess() {
return isSuccess(code);
}
@JsonIgnore // 避免 jackson 序列化
public boolean isError() {
return !isSuccess();
}
// ========= 和 Exception 异常体系集成 =========
/**
* 判断是否有异常。如果有,则抛出 {@link ServiceException} 异常
*/
public void checkError() throws ServiceException {
if (isSuccess()) {
return;
}
// 业务异常
throw new ServiceException(code, msg);
}
/**
* 判断是否有异常。如果有,则抛出 {@link ServiceException} 异常
* 如果没有,则返回 {@link #data} 数据
*/
@JsonIgnore // 避免 jackson 序列化
public T getCheckedData() {
checkError();
return data;
}
public static <T> CommonResult<T> error(ServiceException serviceException) {
return error(serviceException.getCode(), serviceException.getMessage());
}
}

View File

@ -0,0 +1,36 @@
package cn.iocoder.yudao.framework.common.pojo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import javax.validation.constraints.Min;
import javax.validation.constraints.Max;
import javax.validation.constraints.NotNull;
import java.io.Serializable;
@Schema(description="分页参数")
@Data
public class PageParam implements Serializable {
private static final Integer PAGE_NO = 1;
private static final Integer PAGE_SIZE = 10;
/**
* 每页条数 - 不分页
*
* 例如说,导出接口,可以设置 {@link #pageSize} 为 -1 不分页,查询所有数据。
*/
public static final Integer PAGE_SIZE_NONE = -1;
@Schema(description = "页码,从 1 开始", requiredMode = Schema.RequiredMode.REQUIRED,example = "1")
@NotNull(message = "页码不能为空")
@Min(value = 1, message = "页码最小值为 1")
private Integer pageNo = PAGE_NO;
@Schema(description = "每页条数,最大值为 100", requiredMode = Schema.RequiredMode.REQUIRED, example = "10")
@NotNull(message = "每页条数不能为空")
@Min(value = 1, message = "每页条数最小值为 1")
@Max(value = 100, message = "每页条数最大值为 100")
private Integer pageSize = PAGE_SIZE;
}

View File

@ -0,0 +1,41 @@
package cn.iocoder.yudao.framework.common.pojo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
@Schema(description = "分页结果")
@Data
public final class PageResult<T> implements Serializable {
@Schema(description = "数据", requiredMode = Schema.RequiredMode.REQUIRED)
private List<T> list;
@Schema(description = "总量", requiredMode = Schema.RequiredMode.REQUIRED)
private Long total;
public PageResult() {
}
public PageResult(List<T> list, Long total) {
this.list = list;
this.total = total;
}
public PageResult(Long total) {
this.list = new ArrayList<>();
this.total = total;
}
public static <T> PageResult<T> empty() {
return new PageResult<>(0L);
}
public static <T> PageResult<T> empty(Long total) {
return new PageResult<>(total);
}
}

View File

@ -0,0 +1,19 @@
package cn.iocoder.yudao.framework.common.pojo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import java.util.List;
@Schema(description = "可排序的分页参数")
@Data
@EqualsAndHashCode(callSuper = true)
@ToString(callSuper = true)
public class SortablePageParam extends PageParam {
@Schema(description = "排序字段")
private List<SortingField> sortingFields;
}

View File

@ -0,0 +1,37 @@
package cn.iocoder.yudao.framework.common.pojo;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
/**
* 排序字段 DTO
*
* 类名加了 ing 的原因是,避免和 ES SortField 重名。
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
public class SortingField implements Serializable {
/**
* 顺序 - 升序
*/
public static final String ORDER_ASC = "asc";
/**
* 顺序 - 降序
*/
public static final String ORDER_DESC = "desc";
/**
* 字段
*/
private String field;
/**
* 顺序
*/
private String order;
}

View File

@ -0,0 +1,29 @@
package cn.iocoder.yudao.framework.common.util.cache;
import com.alibaba.ttl.threadpool.TtlExecutors;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import java.time.Duration;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
/**
* Cache 工具类
*
* @author 芋道源码
*/
public class CacheUtils {
public static <K, V> LoadingCache<K, V> buildAsyncReloadingCache(Duration duration, CacheLoader<K, V> loader) {
Executor executor = Executors.newCachedThreadPool( // TODO 芋艿:可能要思考下,未来要不要做成可配置
TtlExecutors.getDefaultDisableInheritableThreadFactory()); // TTL 保证 ThreadLocal 可以透传
return CacheBuilder.newBuilder()
// 只阻塞当前数据加载线程,其他线程返回旧值
.refreshAfterWrite(duration)
// 通过 asyncReloading 实现全异步加载,包括 refreshAfterWrite 被阻塞的加载线程
.build(CacheLoader.asyncReloading(loader, executor));
}
}

View File

@ -0,0 +1,58 @@
package cn.iocoder.yudao.framework.common.util.collection;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.collection.IterUtil;
import cn.hutool.core.util.ArrayUtil;
import java.util.Collection;
import java.util.function.Consumer;
import java.util.function.Function;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
/**
* Array 工具类
*
* @author 芋道源码
*/
public class ArrayUtils {
/**
* 将 object 和 newElements 合并成一个数组
*
* @param object 对象
* @param newElements 数组
* @param <T> 泛型
* @return 结果数组
*/
@SafeVarargs
public static <T> Consumer<T>[] append(Consumer<T> object, Consumer<T>... newElements) {
if (object == null) {
return newElements;
}
Consumer<T>[] result = ArrayUtil.newArray(Consumer.class, 1 + newElements.length);
result[0] = object;
System.arraycopy(newElements, 0, result, 1, newElements.length);
return result;
}
public static <T, V> V[] toArray(Collection<T> from, Function<T, V> mapper) {
return toArray(convertList(from, mapper));
}
@SuppressWarnings("unchecked")
public static <T> T[] toArray(Collection<T> from) {
if (CollectionUtil.isEmpty(from)) {
return (T[]) (new Object[0]);
}
return ArrayUtil.toArray(from, (Class<T>) IterUtil.getElementType(from.iterator()));
}
public static <T> T get(T[] array, int index) {
if (null == array || index >= array.length) {
return null;
}
return array[index];
}
}

View File

@ -0,0 +1,318 @@
package cn.iocoder.yudao.framework.common.util.collection;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.ArrayUtil;
import com.google.common.collect.ImmutableMap;
import java.util.*;
import java.util.function.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static java.util.Arrays.asList;
/**
* Collection 工具类
*
* @author 芋道源码
*/
public class CollectionUtils {
public static boolean containsAny(Object source, Object... targets) {
return asList(targets).contains(source);
}
public static boolean isAnyEmpty(Collection<?>... collections) {
return Arrays.stream(collections).anyMatch(CollectionUtil::isEmpty);
}
public static <T> boolean anyMatch(Collection<T> from, Predicate<T> predicate) {
return from.stream().anyMatch(predicate);
}
public static <T> List<T> filterList(Collection<T> from, Predicate<T> predicate) {
if (CollUtil.isEmpty(from)) {
return new ArrayList<>();
}
return from.stream().filter(predicate).collect(Collectors.toList());
}
public static <T, R> List<T> distinct(Collection<T> from, Function<T, R> keyMapper) {
if (CollUtil.isEmpty(from)) {
return new ArrayList<>();
}
return distinct(from, keyMapper, (t1, t2) -> t1);
}
public static <T, R> List<T> distinct(Collection<T> from, Function<T, R> keyMapper, BinaryOperator<T> cover) {
if (CollUtil.isEmpty(from)) {
return new ArrayList<>();
}
return new ArrayList<>(convertMap(from, keyMapper, Function.identity(), cover).values());
}
public static <T, U> List<U> convertList(T[] from, Function<T, U> func) {
if (ArrayUtil.isEmpty(from)) {
return new ArrayList<>();
}
return convertList(Arrays.asList(from), func);
}
public static <T, U> List<U> convertList(Collection<T> from, Function<T, U> func) {
if (CollUtil.isEmpty(from)) {
return new ArrayList<>();
}
return from.stream().map(func).filter(Objects::nonNull).collect(Collectors.toList());
}
public static <T, U> List<U> convertList(Collection<T> from, Function<T, U> func, Predicate<T> filter) {
if (CollUtil.isEmpty(from)) {
return new ArrayList<>();
}
return from.stream().filter(filter).map(func).filter(Objects::nonNull).collect(Collectors.toList());
}
public static <T, U> List<U> convertListByFlatMap(Collection<T> from,
Function<T, ? extends Stream<? extends U>> func) {
if (CollUtil.isEmpty(from)) {
return new ArrayList<>();
}
return from.stream().flatMap(func).filter(Objects::nonNull).collect(Collectors.toList());
}
public static <T, U, R> List<R> convertListByFlatMap(Collection<T> from,
Function<? super T, ? extends U> mapper,
Function<U, ? extends Stream<? extends R>> func) {
if (CollUtil.isEmpty(from)) {
return new ArrayList<>();
}
return from.stream().map(mapper).flatMap(func).filter(Objects::nonNull).collect(Collectors.toList());
}
public static <K, V> List<V> mergeValuesFromMap(Map<K, List<V>> map) {
return map.values()
.stream()
.flatMap(List::stream)
.collect(Collectors.toList());
}
public static <T, U> Set<U> convertSet(Collection<T> from, Function<T, U> func) {
if (CollUtil.isEmpty(from)) {
return new HashSet<>();
}
return from.stream().map(func).filter(Objects::nonNull).collect(Collectors.toSet());
}
public static <T, U> Set<U> convertSet(Collection<T> from, Function<T, U> func, Predicate<T> filter) {
if (CollUtil.isEmpty(from)) {
return new HashSet<>();
}
return from.stream().filter(filter).map(func).filter(Objects::nonNull).collect(Collectors.toSet());
}
public static <T, K> Map<K, T> convertMapByFilter(Collection<T> from, Predicate<T> filter, Function<T, K> keyFunc) {
if (CollUtil.isEmpty(from)) {
return new HashMap<>();
}
return from.stream().filter(filter).collect(Collectors.toMap(keyFunc, v -> v));
}
public static <T, U> Set<U> convertSetByFlatMap(Collection<T> from,
Function<T, ? extends Stream<? extends U>> func) {
if (CollUtil.isEmpty(from)) {
return new HashSet<>();
}
return from.stream().flatMap(func).filter(Objects::nonNull).collect(Collectors.toSet());
}
public static <T, U, R> Set<R> convertSetByFlatMap(Collection<T> from,
Function<? super T, ? extends U> mapper,
Function<U, ? extends Stream<? extends R>> func) {
if (CollUtil.isEmpty(from)) {
return new HashSet<>();
}
return from.stream().map(mapper).flatMap(func).filter(Objects::nonNull).collect(Collectors.toSet());
}
public static <T, K> Map<K, T> convertMap(Collection<T> from, Function<T, K> keyFunc) {
if (CollUtil.isEmpty(from)) {
return new HashMap<>();
}
return convertMap(from, keyFunc, Function.identity());
}
public static <T, K> Map<K, T> convertMap(Collection<T> from, Function<T, K> keyFunc, Supplier<? extends Map<K, T>> supplier) {
if (CollUtil.isEmpty(from)) {
return supplier.get();
}
return convertMap(from, keyFunc, Function.identity(), supplier);
}
public static <T, K, V> Map<K, V> convertMap(Collection<T> from, Function<T, K> keyFunc, Function<T, V> valueFunc) {
if (CollUtil.isEmpty(from)) {
return new HashMap<>();
}
return convertMap(from, keyFunc, valueFunc, (v1, v2) -> v1);
}
public static <T, K, V> Map<K, V> convertMap(Collection<T> from, Function<T, K> keyFunc, Function<T, V> valueFunc, BinaryOperator<V> mergeFunction) {
if (CollUtil.isEmpty(from)) {
return new HashMap<>();
}
return convertMap(from, keyFunc, valueFunc, mergeFunction, HashMap::new);
}
public static <T, K, V> Map<K, V> convertMap(Collection<T> from, Function<T, K> keyFunc, Function<T, V> valueFunc, Supplier<? extends Map<K, V>> supplier) {
if (CollUtil.isEmpty(from)) {
return supplier.get();
}
return convertMap(from, keyFunc, valueFunc, (v1, v2) -> v1, supplier);
}
public static <T, K, V> Map<K, V> convertMap(Collection<T> from, Function<T, K> keyFunc, Function<T, V> valueFunc, BinaryOperator<V> mergeFunction, Supplier<? extends Map<K, V>> supplier) {
if (CollUtil.isEmpty(from)) {
return new HashMap<>();
}
return from.stream().collect(Collectors.toMap(keyFunc, valueFunc, mergeFunction, supplier));
}
public static <T, K> Map<K, List<T>> convertMultiMap(Collection<T> from, Function<T, K> keyFunc) {
if (CollUtil.isEmpty(from)) {
return new HashMap<>();
}
return from.stream().collect(Collectors.groupingBy(keyFunc, Collectors.mapping(t -> t, Collectors.toList())));
}
public static <T, K, V> Map<K, List<V>> convertMultiMap(Collection<T> from, Function<T, K> keyFunc, Function<T, V> valueFunc) {
if (CollUtil.isEmpty(from)) {
return new HashMap<>();
}
return from.stream()
.collect(Collectors.groupingBy(keyFunc, Collectors.mapping(valueFunc, Collectors.toList())));
}
// 暂时没想好名字,先以 2 结尾噶
public static <T, K, V> Map<K, Set<V>> convertMultiMap2(Collection<T> from, Function<T, K> keyFunc, Function<T, V> valueFunc) {
if (CollUtil.isEmpty(from)) {
return new HashMap<>();
}
return from.stream().collect(Collectors.groupingBy(keyFunc, Collectors.mapping(valueFunc, Collectors.toSet())));
}
public static <T, K> Map<K, T> convertImmutableMap(Collection<T> from, Function<T, K> keyFunc) {
if (CollUtil.isEmpty(from)) {
return Collections.emptyMap();
}
ImmutableMap.Builder<K, T> builder = ImmutableMap.builder();
from.forEach(item -> builder.put(keyFunc.apply(item), item));
return builder.build();
}
/**
* 对比老、新两个列表,找出新增、修改、删除的数据
*
* @param oldList 老列表
* @param newList 新列表
* @param sameFunc 对比函数,返回 true 表示相同,返回 false 表示不同
* 注意same 是通过每个元素的“标识”,判断它们是不是同一个数据
* @return [新增列表、修改列表、删除列表]
*/
public static <T> List<List<T>> diffList(Collection<T> oldList, Collection<T> newList,
BiFunction<T, T, Boolean> sameFunc) {
List<T> createList = new LinkedList<>(newList); // 默认都认为是新增的,后续会进行移除
List<T> updateList = new ArrayList<>();
List<T> deleteList = new ArrayList<>();
// 通过以 oldList 为主遍历,找出 updateList 和 deleteList
for (T oldObj : oldList) {
// 1. 寻找是否有匹配的
T foundObj = null;
for (Iterator<T> iterator = createList.iterator(); iterator.hasNext(); ) {
T newObj = iterator.next();
// 1.1 不匹配,则直接跳过
if (!sameFunc.apply(oldObj, newObj)) {
continue;
}
// 1.2 匹配,则移除,并结束寻找
iterator.remove();
foundObj = newObj;
break;
}
// 2. 匹配添加到 updateList不匹配则添加到 deleteList 中
if (foundObj != null) {
updateList.add(foundObj);
} else {
deleteList.add(oldObj);
}
}
return asList(createList, updateList, deleteList);
}
public static boolean containsAny(Collection<?> source, Collection<?> candidates) {
return org.springframework.util.CollectionUtils.containsAny(source, candidates);
}
public static <T> T getFirst(List<T> from) {
return !CollectionUtil.isEmpty(from) ? from.get(0) : null;
}
public static <T> T findFirst(Collection<T> from, Predicate<T> predicate) {
return findFirst(from, predicate, Function.identity());
}
public static <T, U> U findFirst(Collection<T> from, Predicate<T> predicate, Function<T, U> func) {
if (CollUtil.isEmpty(from)) {
return null;
}
return from.stream().filter(predicate).findFirst().map(func).orElse(null);
}
public static <T, V extends Comparable<? super V>> V getMaxValue(Collection<T> from, Function<T, V> valueFunc) {
if (CollUtil.isEmpty(from)) {
return null;
}
assert !from.isEmpty(); // 断言,避免告警
T t = from.stream().max(Comparator.comparing(valueFunc)).get();
return valueFunc.apply(t);
}
public static <T, V extends Comparable<? super V>> V getMinValue(List<T> from, Function<T, V> valueFunc) {
if (CollUtil.isEmpty(from)) {
return null;
}
assert from.size() > 0; // 断言,避免告警
T t = from.stream().min(Comparator.comparing(valueFunc)).get();
return valueFunc.apply(t);
}
public static <T, V extends Comparable<? super V>> V getSumValue(List<T> from, Function<T, V> valueFunc,
BinaryOperator<V> accumulator) {
return getSumValue(from, valueFunc, accumulator, null);
}
public static <T, V extends Comparable<? super V>> V getSumValue(Collection<T> from, Function<T, V> valueFunc,
BinaryOperator<V> accumulator, V defaultValue) {
if (CollUtil.isEmpty(from)) {
return defaultValue;
}
assert !from.isEmpty(); // 断言,避免告警
return from.stream().map(valueFunc).filter(Objects::nonNull).reduce(accumulator).orElse(defaultValue);
}
public static <T> void addIfNotNull(Collection<T> coll, T item) {
if (item == null) {
return;
}
coll.add(item);
}
public static <T> Collection<T> singleton(T obj) {
return obj == null ? Collections.emptyList() : Collections.singleton(obj);
}
public static <T> List<T> newArrayList(List<List<T>> list) {
return list.stream().flatMap(Collection::stream).collect(Collectors.toList());
}
}

View File

@ -0,0 +1,66 @@
package cn.iocoder.yudao.framework.common.util.collection;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.CollectionUtil;
import cn.iocoder.yudao.framework.common.core.KeyValue;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
/**
* Map 工具类
*
* @author 芋道源码
*/
public class MapUtils {
/**
* 从哈希表表中,获得 keys 对应的所有 value 数组
*
* @param multimap 哈希表
* @param keys keys
* @return value 数组
*/
public static <K, V> List<V> getList(Multimap<K, V> multimap, Collection<K> keys) {
List<V> result = new ArrayList<>();
keys.forEach(k -> {
Collection<V> values = multimap.get(k);
if (CollectionUtil.isEmpty(values)) {
return;
}
result.addAll(values);
});
return result;
}
/**
* 从哈希表查找到 key 对应的 value然后进一步处理
* 注意,如果查找到的 value 为 null 时,不进行处理
*
* @param map 哈希表
* @param key key
* @param consumer 进一步处理的逻辑
*/
public static <K, V> void findAndThen(Map<K, V> map, K key, Consumer<V> consumer) {
if (CollUtil.isEmpty(map)) {
return;
}
V value = map.get(key);
if (value == null) {
return;
}
consumer.accept(value);
}
public static <K, V> Map<K, V> convertMap(List<KeyValue<K, V>> keyValues) {
Map<K, V> map = Maps.newLinkedHashMapWithExpectedSize(keyValues.size());
keyValues.forEach(keyValue -> map.put(keyValue.getKey(), keyValue.getValue()));
return map;
}
}

View File

@ -0,0 +1,19 @@
package cn.iocoder.yudao.framework.common.util.collection;
import cn.hutool.core.collection.CollUtil;
import java.util.Set;
/**
* Set 工具类
*
* @author 芋道源码
*/
public class SetUtils {
@SafeVarargs
public static <T> Set<T> asSet(T... objs) {
return CollUtil.newHashSet(objs);
}
}

View File

@ -0,0 +1,190 @@
package cn.iocoder.yudao.framework.common.util.date;
import cn.hutool.core.date.LocalDateTimeUtil;
import java.time.*;
import java.util.Calendar;
import java.util.Date;
/**
* 时间工具类
*
* @author 芋道源码
*/
public class DateUtils {
/**
* 时区 - 默认
*/
public static final String TIME_ZONE_DEFAULT = "GMT+8";
/**
* 秒转换成毫秒
*/
public static final long SECOND_MILLIS = 1000;
public static final String FORMAT_YEAR_MONTH_DAY = "yyyy-MM-dd";
public static final String FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND = "yyyy-MM-dd HH:mm:ss";
public static final String FORMAT_HOUR_MINUTE_SECOND = "HH:mm:ss";
/**
* 将 LocalDateTime 转换成 Date
*
* @param date LocalDateTime
* @return LocalDateTime
*/
public static Date of(LocalDateTime date) {
if (date == null) {
return null;
}
// 将此日期时间与时区相结合以创建 ZonedDateTime
ZonedDateTime zonedDateTime = date.atZone(ZoneId.systemDefault());
// 本地时间线 LocalDateTime 到即时时间线 Instant 时间戳
Instant instant = zonedDateTime.toInstant();
// UTC时间(世界协调时间,UTC + 00:00)转北京(北京,UTC + 8:00)时间
return Date.from(instant);
}
/**
* 将 Date 转换成 LocalDateTime
*
* @param date Date
* @return LocalDateTime
*/
public static LocalDateTime of(Date date) {
if (date == null) {
return null;
}
// 转为时间戳
Instant instant = date.toInstant();
// UTC时间(世界协调时间,UTC + 00:00)转北京(北京,UTC + 8:00)时间
return LocalDateTime.ofInstant(instant, ZoneId.systemDefault());
}
public static Date addTime(Duration duration) {
return new Date(System.currentTimeMillis() + duration.toMillis());
}
public static boolean isExpired(Date time) {
return System.currentTimeMillis() > time.getTime();
}
public static boolean isExpired(LocalDateTime time) {
LocalDateTime now = LocalDateTime.now();
return now.isAfter(time);
}
public static long diff(Date endTime, Date startTime) {
return endTime.getTime() - startTime.getTime();
}
/**
* 创建指定时间
*
* @param year 年
* @param mouth 月
* @param day 日
* @return 指定时间
*/
public static Date buildTime(int year, int mouth, int day) {
return buildTime(year, mouth, day, 0, 0, 0);
}
/**
* 创建指定时间
*
* @param year 年
* @param mouth 月
* @param day 日
* @param hour 小时
* @param minute 分钟
* @param second 秒
* @return 指定时间
*/
public static Date buildTime(int year, int mouth, int day,
int hour, int minute, int second) {
Calendar calendar = Calendar.getInstance();
calendar.set(Calendar.YEAR, year);
calendar.set(Calendar.MONTH, mouth - 1);
calendar.set(Calendar.DAY_OF_MONTH, day);
calendar.set(Calendar.HOUR_OF_DAY, hour);
calendar.set(Calendar.MINUTE, minute);
calendar.set(Calendar.SECOND, second);
calendar.set(Calendar.MILLISECOND, 0); // 一般情况下,都是 0 毫秒
return calendar.getTime();
}
public static Date max(Date a, Date b) {
if (a == null) {
return b;
}
if (b == null) {
return a;
}
return a.compareTo(b) > 0 ? a : b;
}
public static LocalDateTime max(LocalDateTime a, LocalDateTime b) {
if (a == null) {
return b;
}
if (b == null) {
return a;
}
return a.isAfter(b) ? a : b;
}
/**
* 计算当期时间相差的日期
*
* @param field 日历字段.<br/>eg:Calendar.MONTH,Calendar.DAY_OF_MONTH,<br/>Calendar.HOUR_OF_DAY等.
* @param amount 相差的数值
* @return 计算后的日志
*/
public static Date addDate(int field, int amount) {
return addDate(null, field, amount);
}
/**
* 计算当期时间相差的日期
*
* @param date 设置时间
* @param field 日历字段 例如说,{@link Calendar#DAY_OF_MONTH} 等
* @param amount 相差的数值
* @return 计算后的日志
*/
public static Date addDate(Date date, int field, int amount) {
if (amount == 0) {
return date;
}
Calendar c = Calendar.getInstance();
if (date != null) {
c.setTime(date);
}
c.add(field, amount);
return c.getTime();
}
/**
* 是否今天
*
* @param date 日期
* @return 是否
*/
public static boolean isToday(LocalDateTime date) {
return LocalDateTimeUtil.isSameDay(date, LocalDateTime.now());
}
/**
* 是否昨天
*
* @param date 日期
* @return 是否
*/
public static boolean isYesterday(LocalDateTime date) {
return LocalDateTimeUtil.isSameDay(date, LocalDateTime.now().minusDays(1));
}
}

View File

@ -0,0 +1,171 @@
package cn.iocoder.yudao.framework.common.util.date;
import cn.hutool.core.date.LocalDateTimeUtil;
import java.time.Duration;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalAdjusters;
/**
* 时间工具类,用于 {@link java.time.LocalDateTime}
*
* @author 芋道源码
*/
public class LocalDateTimeUtils {
/**
* 空的 LocalDateTime 对象,主要用于 DB 唯一索引的默认值
*/
public static LocalDateTime EMPTY = buildTime(1970, 1, 1);
public static LocalDateTime addTime(Duration duration) {
return LocalDateTime.now().plus(duration);
}
public static LocalDateTime minusTime(Duration duration) {
return LocalDateTime.now().minus(duration);
}
public static boolean beforeNow(LocalDateTime date) {
return date.isBefore(LocalDateTime.now());
}
public static boolean afterNow(LocalDateTime date) {
return date.isAfter(LocalDateTime.now());
}
/**
* 创建指定时间
*
* @param year 年
* @param mouth 月
* @param day 日
* @return 指定时间
*/
public static LocalDateTime buildTime(int year, int mouth, int day) {
return LocalDateTime.of(year, mouth, day, 0, 0, 0);
}
public static LocalDateTime[] buildBetweenTime(int year1, int mouth1, int day1,
int year2, int mouth2, int day2) {
return new LocalDateTime[]{buildTime(year1, mouth1, day1), buildTime(year2, mouth2, day2)};
}
/**
* 判断当前时间是否在该时间范围内
*
* @param startTime 开始时间
* @param endTime 结束时间
* @return 是否
*/
public static boolean isBetween(LocalDateTime startTime, LocalDateTime endTime) {
if (startTime == null || endTime == null) {
return false;
}
return LocalDateTimeUtil.isIn(LocalDateTime.now(), startTime, endTime);
}
/**
* 判断当前时间是否在该时间范围内
*
* @param startTime 开始时间
* @param endTime 结束时间
* @return 是否
*/
public static boolean isBetween(String startTime, String endTime) {
if (startTime == null || endTime == null) {
return false;
}
LocalDate nowDate = LocalDate.now();
return LocalDateTimeUtil.isIn(LocalDateTime.now(),
LocalDateTime.of(nowDate, LocalTime.parse(startTime)),
LocalDateTime.of(nowDate, LocalTime.parse(endTime)));
}
/**
* 判断时间段是否重叠
*
* @param startTime1 开始 time1
* @param endTime1 结束 time1
* @param startTime2 开始 time2
* @param endTime2 结束 time2
* @return 重叠true 不重叠false
*/
public static boolean isOverlap(LocalTime startTime1, LocalTime endTime1, LocalTime startTime2, LocalTime endTime2) {
LocalDate nowDate = LocalDate.now();
return LocalDateTimeUtil.isOverlap(LocalDateTime.of(nowDate, startTime1), LocalDateTime.of(nowDate, endTime1),
LocalDateTime.of(nowDate, startTime2), LocalDateTime.of(nowDate, endTime2));
}
/**
* 获取指定日期所在的月份的开始时间
* 例如2023-09-30 00:00:00,000
*
* @param date 日期
* @return 月份的开始时间
*/
public static LocalDateTime beginOfMonth(LocalDateTime date) {
return date.with(TemporalAdjusters.firstDayOfMonth()).with(LocalTime.MIN);
}
/**
* 获取指定日期所在的月份的最后时间
* 例如2023-09-30 23:59:59,999
*
* @param date 日期
* @return 月份的结束时间
*/
public static LocalDateTime endOfMonth(LocalDateTime date) {
return date.with(TemporalAdjusters.lastDayOfMonth()).with(LocalTime.MAX);
}
/**
* 获取指定日期到现在过了几天,如果指定日期在当前日期之后,获取结果为负
*
* @param dateTime 日期
* @return 相差天数
*/
public static Long between(LocalDateTime dateTime) {
return LocalDateTimeUtil.between(dateTime, LocalDateTime.now(), ChronoUnit.DAYS);
}
/**
* 获取今天的开始时间
*
* @return 今天
*/
public static LocalDateTime getToday() {
return LocalDateTimeUtil.beginOfDay(LocalDateTime.now());
}
/**
* 获取昨天的开始时间
*
* @return 昨天
*/
public static LocalDateTime getYesterday() {
return LocalDateTimeUtil.beginOfDay(LocalDateTime.now().minusDays(1));
}
/**
* 获取本月的开始时间
*
* @return 本月
*/
public static LocalDateTime getMonth() {
return beginOfMonth(LocalDateTime.now());
}
/**
* 获取本年的开始时间
*
* @return 本年
*/
public static LocalDateTime getYear() {
return LocalDateTime.now().with(TemporalAdjusters.firstDayOfYear()).with(LocalTime.MIN);
}
}

View File

@ -0,0 +1,126 @@
package cn.iocoder.yudao.framework.common.util.http;
import cn.hutool.core.codec.Base64;
import cn.hutool.core.map.TableMap;
import cn.hutool.core.net.url.UrlBuilder;
import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.http.HttpServletRequest;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.Map;
/**
* HTTP 工具类
*
* @author 芋道源码
*/
public class HttpUtils {
@SuppressWarnings("unchecked")
public static String replaceUrlQuery(String url, String key, String value) {
UrlBuilder builder = UrlBuilder.of(url, Charset.defaultCharset());
// 先移除
TableMap<CharSequence, CharSequence> query = (TableMap<CharSequence, CharSequence>)
ReflectUtil.getFieldValue(builder.getQuery(), "query");
query.remove(key);
// 后添加
builder.addQuery(key, value);
return builder.build();
}
private String append(String base, Map<String, ?> query, boolean fragment) {
return append(base, query, null, fragment);
}
/**
* 拼接 URL
*
* copy from Spring Security OAuth2 的 AuthorizationEndpoint 类的 append 方法
*
* @param base 基础 URL
* @param query 查询参数
* @param keys query 的 key对应的原本的 key 的映射。例如说 query 里有个 key 是 xx实际它的 key 是 extra_xx则通过 keys 里添加这个映射
* @param fragment URL 的 fragment即拼接到 # 中
* @return 拼接后的 URL
*/
public static String append(String base, Map<String, ?> query, Map<String, String> keys, boolean fragment) {
UriComponentsBuilder template = UriComponentsBuilder.newInstance();
UriComponentsBuilder builder = UriComponentsBuilder.fromUriString(base);
URI redirectUri;
try {
// assume it's encoded to start with (if it came in over the wire)
redirectUri = builder.build(true).toUri();
} catch (Exception e) {
// ... but allow client registrations to contain hard-coded non-encoded values
redirectUri = builder.build().toUri();
builder = UriComponentsBuilder.fromUri(redirectUri);
}
template.scheme(redirectUri.getScheme()).port(redirectUri.getPort()).host(redirectUri.getHost())
.userInfo(redirectUri.getUserInfo()).path(redirectUri.getPath());
if (fragment) {
StringBuilder values = new StringBuilder();
if (redirectUri.getFragment() != null) {
String append = redirectUri.getFragment();
values.append(append);
}
for (String key : query.keySet()) {
if (values.length() > 0) {
values.append("&");
}
String name = key;
if (keys != null && keys.containsKey(key)) {
name = keys.get(key);
}
values.append(name).append("={").append(key).append("}");
}
if (values.length() > 0) {
template.fragment(values.toString());
}
UriComponents encoded = template.build().expand(query).encode();
builder.fragment(encoded.getFragment());
} else {
for (String key : query.keySet()) {
String name = key;
if (keys != null && keys.containsKey(key)) {
name = keys.get(key);
}
template.queryParam(name, "{" + key + "}");
}
template.fragment(redirectUri.getFragment());
UriComponents encoded = template.build().expand(query).encode();
builder.query(encoded.getQuery());
}
return builder.build().toUriString();
}
public static String[] obtainBasicAuthorization(HttpServletRequest request) {
String clientId;
String clientSecret;
// 先从 Header 中获取
String authorization = request.getHeader("Authorization");
authorization = StrUtil.subAfter(authorization, "Basic ", true);
if (StringUtils.hasText(authorization)) {
authorization = Base64.decodeStr(authorization);
clientId = StrUtil.subBefore(authorization, ":", false);
clientSecret = StrUtil.subAfter(authorization, ":", false);
// 再从 Param 中获取
} else {
clientId = request.getParameter("client_id");
clientSecret = request.getParameter("client_secret");
}
// 如果两者非空,则返回
if (StrUtil.isNotEmpty(clientId) && StrUtil.isNotEmpty(clientSecret)) {
return new String[]{clientId, clientSecret};
}
return null;
}
}

View File

@ -0,0 +1,84 @@
package cn.iocoder.yudao.framework.common.util.io;
import cn.hutool.core.io.FileTypeUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.io.file.FileNameUtil;
import cn.hutool.core.util.IdUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.crypto.digest.DigestUtil;
import lombok.SneakyThrows;
import java.io.ByteArrayInputStream;
import java.io.File;
/**
* 文件工具类
*
* @author 芋道源码
*/
public class FileUtils {
/**
* 创建临时文件
* 该文件会在 JVM 退出时,进行删除
*
* @param data 文件内容
* @return 文件
*/
@SneakyThrows
public static File createTempFile(String data) {
File file = createTempFile();
// 写入内容
FileUtil.writeUtf8String(data, file);
return file;
}
/**
* 创建临时文件
* 该文件会在 JVM 退出时,进行删除
*
* @param data 文件内容
* @return 文件
*/
@SneakyThrows
public static File createTempFile(byte[] data) {
File file = createTempFile();
// 写入内容
FileUtil.writeBytes(data, file);
return file;
}
/**
* 创建临时文件,无内容
* 该文件会在 JVM 退出时,进行删除
*
* @return 文件
*/
@SneakyThrows
public static File createTempFile() {
// 创建文件,通过 UUID 保证唯一
File file = File.createTempFile(IdUtil.simpleUUID(), null);
// 标记 JVM 退出时,自动删除
file.deleteOnExit();
return file;
}
/**
* 生成文件路径
*
* @param content 文件内容
* @param originalName 原始文件名
* @return path唯一不可重复
*/
public static String generatePath(byte[] content, String originalName) {
String sha256Hex = DigestUtil.sha256Hex(content);
// 情况一:如果存在 name则优先使用 name 的后缀
if (StrUtil.isNotBlank(originalName)) {
String extName = FileNameUtil.extName(originalName);
return StrUtil.isBlank(extName) ? sha256Hex : sha256Hex + "." + extName;
}
// 情况二:基于 content 计算
return sha256Hex + '.' + FileTypeUtil.getType(new ByteArrayInputStream(content));
}
}

View File

@ -0,0 +1,28 @@
package cn.iocoder.yudao.framework.common.util.io;
import cn.hutool.core.io.IORuntimeException;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.StrUtil;
import java.io.InputStream;
/**
* IO 工具类,用于 {@link cn.hutool.core.io.IoUtil} 缺失的方法
*
* @author 芋道源码
*/
public class IoUtils {
/**
* 从流中读取 UTF8 编码的内容
*
* @param in 输入流
* @param isClose 是否关闭
* @return 内容
* @throws IORuntimeException IO 异常
*/
public static String readUtf8(InputStream in, boolean isClose) throws IORuntimeException {
return StrUtil.utf8Str(IoUtil.read(in, isClose));
}
}

View File

@ -0,0 +1,202 @@
package cn.iocoder.yudao.framework.common.util.json;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
/**
* JSON 工具类
*
* @author 芋道源码
*/
@Slf4j
public class JsonUtils {
private static ObjectMapper objectMapper = new ObjectMapper();
static {
objectMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); // 忽略 null 值
objectMapper.registerModules(new JavaTimeModule()); // 解决 LocalDateTime 的序列化
}
/**
* 初始化 objectMapper 属性
* <p>
* 通过这样的方式,使用 Spring 创建的 ObjectMapper Bean
*
* @param objectMapper ObjectMapper 对象
*/
public static void init(ObjectMapper objectMapper) {
JsonUtils.objectMapper = objectMapper;
}
@SneakyThrows
public static String toJsonString(Object object) {
return objectMapper.writeValueAsString(object);
}
@SneakyThrows
public static byte[] toJsonByte(Object object) {
return objectMapper.writeValueAsBytes(object);
}
@SneakyThrows
public static String toJsonPrettyString(Object object) {
return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(object);
}
public static <T> T parseObject(String text, Class<T> clazz) {
if (StrUtil.isEmpty(text)) {
return null;
}
try {
return objectMapper.readValue(text, clazz);
} catch (IOException e) {
log.error("json parse err,json:{}", text, e);
throw new RuntimeException(e);
}
}
public static <T> T parseObject(String text, String path, Class<T> clazz) {
if (StrUtil.isEmpty(text)) {
return null;
}
try {
JsonNode treeNode = objectMapper.readTree(text);
JsonNode pathNode = treeNode.path(path);
return objectMapper.readValue(pathNode.toString(), clazz);
} catch (IOException e) {
log.error("json parse err,json:{}", text, e);
throw new RuntimeException(e);
}
}
public static <T> T parseObject(String text, Type type) {
if (StrUtil.isEmpty(text)) {
return null;
}
try {
return objectMapper.readValue(text, objectMapper.getTypeFactory().constructType(type));
} catch (IOException e) {
log.error("json parse err,json:{}", text, e);
throw new RuntimeException(e);
}
}
/**
* 将字符串解析成指定类型的对象
* 使用 {@link #parseObject(String, Class)} 时,在@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS) 的场景下,
* 如果 text 没有 class 属性,则会报错。此时,使用这个方法,可以解决。
*
* @param text 字符串
* @param clazz 类型
* @return 对象
*/
public static <T> T parseObject2(String text, Class<T> clazz) {
if (StrUtil.isEmpty(text)) {
return null;
}
return JSONUtil.toBean(text, clazz);
}
public static <T> T parseObject(byte[] bytes, Class<T> clazz) {
if (ArrayUtil.isEmpty(bytes)) {
return null;
}
try {
return objectMapper.readValue(bytes, clazz);
} catch (IOException e) {
log.error("json parse err,json:{}", bytes, e);
throw new RuntimeException(e);
}
}
public static <T> T parseObject(String text, TypeReference<T> typeReference) {
try {
return objectMapper.readValue(text, typeReference);
} catch (IOException e) {
log.error("json parse err,json:{}", text, e);
throw new RuntimeException(e);
}
}
/**
* 解析 JSON 字符串成指定类型的对象,如果解析失败,则返回 null
*
* @param text 字符串
* @param typeReference 类型引用
* @return 指定类型的对象
*/
public static <T> T parseObjectQuietly(String text, TypeReference<T> typeReference) {
try {
return objectMapper.readValue(text, typeReference);
} catch (IOException e) {
return null;
}
}
public static <T> List<T> parseArray(String text, Class<T> clazz) {
if (StrUtil.isEmpty(text)) {
return new ArrayList<>();
}
try {
return objectMapper.readValue(text, objectMapper.getTypeFactory().constructCollectionType(List.class, clazz));
} catch (IOException e) {
log.error("json parse err,json:{}", text, e);
throw new RuntimeException(e);
}
}
public static <T> List<T> parseArray(String text, String path, Class<T> clazz) {
if (StrUtil.isEmpty(text)) {
return null;
}
try {
JsonNode treeNode = objectMapper.readTree(text);
JsonNode pathNode = treeNode.path(path);
return objectMapper.readValue(pathNode.toString(), objectMapper.getTypeFactory().constructCollectionType(List.class, clazz));
} catch (IOException e) {
log.error("json parse err,json:{}", text, e);
throw new RuntimeException(e);
}
}
public static JsonNode parseTree(String text) {
try {
return objectMapper.readTree(text);
} catch (IOException e) {
log.error("json parse err,json:{}", text, e);
throw new RuntimeException(e);
}
}
public static JsonNode parseTree(byte[] text) {
try {
return objectMapper.readTree(text);
} catch (IOException e) {
log.error("json parse err,json:{}", text, e);
throw new RuntimeException(e);
}
}
public static boolean isJson(String text) {
return JSONUtil.isTypeJSON(text);
}
}

View File

@ -0,0 +1,30 @@
package cn.iocoder.yudao.framework.common.util.monitor;
import org.apache.skywalking.apm.toolkit.trace.TraceContext;
/**
* 链路追踪工具类
*
* 考虑到每个 starter 都需要用到该工具类,所以放到 common 模块下的 util 包下
*
* @author 芋道源码
*/
public class TracerUtils {
/**
* 私有化构造方法
*/
private TracerUtils() {
}
/**
* 获得链路追踪编号,直接返回 SkyWalking 的 TraceId。
* 如果不存在的话为空字符串!!!
*
* @return 链路追踪编号
*/
public static String getTraceId() {
return TraceContext.traceId();
}
}

View File

@ -0,0 +1,131 @@
package cn.iocoder.yudao.framework.common.util.number;
import cn.hutool.core.math.Money;
import cn.hutool.core.util.NumberUtil;
import java.math.BigDecimal;
import java.math.RoundingMode;
/**
* 金额工具类
*
* @author 芋道源码
*/
public class MoneyUtils {
/**
* 金额的小数位数
*/
private static final int PRICE_SCALE = 2;
/**
* 百分比对应的 BigDecimal 对象
*/
public static final BigDecimal PERCENT_100 = BigDecimal.valueOf(100);
/**
* 计算百分比金额,四舍五入
*
* @param price 金额
* @param rate 百分比,例如说 56.77% 则传入 56.77
* @return 百分比金额
*/
public static Integer calculateRatePrice(Integer price, Double rate) {
return calculateRatePrice(price, rate, 0, RoundingMode.HALF_UP).intValue();
}
/**
* 计算百分比金额,向下传入
*
* @param price 金额
* @param rate 百分比,例如说 56.77% 则传入 56.77
* @return 百分比金额
*/
public static Integer calculateRatePriceFloor(Integer price, Double rate) {
return calculateRatePrice(price, rate, 0, RoundingMode.FLOOR).intValue();
}
/**
* 计算百分比金额
*
* @param price 金额(单位分)
* @param count 数量
* @param percent 折扣(单位分),列如 60.2%,则传入 6020
* @return 商品总价
*/
public static Integer calculator(Integer price, Integer count, Integer percent) {
price = price * count;
if (percent == null) {
return price;
}
return MoneyUtils.calculateRatePriceFloor(price, (double) (percent / 100));
}
/**
* 计算百分比金额
*
* @param price 金额
* @param rate 百分比,例如说 56.77% 则传入 56.77
* @param scale 保留小数位数
* @param roundingMode 舍入模式
*/
public static BigDecimal calculateRatePrice(Number price, Number rate, int scale, RoundingMode roundingMode) {
return NumberUtil.toBigDecimal(price).multiply(NumberUtil.toBigDecimal(rate)) // 乘以
.divide(BigDecimal.valueOf(100), scale, roundingMode); // 除以 100
}
/**
* 分转元
*
* @param fen 分
* @return 元
*/
public static BigDecimal fenToYuan(int fen) {
return new Money(0, fen).getAmount();
}
/**
* 分转元(字符串)
*
* 例如说 fen 为 1 时,则结果为 0.01
*
* @param fen 分
* @return 元
*/
public static String fenToYuanStr(int fen) {
return new Money(0, fen).toString();
}
/**
* 金额相乘,默认进行四舍五入
*
* 位数:{@link #PRICE_SCALE}
*
* @param price 金额
* @param count 数量
* @return 金额相乘结果
*/
public static BigDecimal priceMultiply(BigDecimal price, BigDecimal count) {
if (price == null || count == null) {
return null;
}
return price.multiply(count).setScale(PRICE_SCALE, RoundingMode.HALF_UP);
}
/**
* 金额相乘(百分比),默认进行四舍五入
*
* 位数:{@link #PRICE_SCALE}
*
* @param price 金额
* @param percent 百分比
* @return 金额相乘结果
*/
public static BigDecimal priceMultiplyPercent(BigDecimal price, BigDecimal percent) {
if (price == null || percent == null) {
return null;
}
return price.multiply(percent).divide(PERCENT_100, PRICE_SCALE, RoundingMode.HALF_UP);
}
}

View File

@ -0,0 +1,60 @@
package cn.iocoder.yudao.framework.common.util.number;
import cn.hutool.core.util.NumberUtil;
import cn.hutool.core.util.StrUtil;
import java.math.BigDecimal;
/**
* 数字的工具类,补全 {@link cn.hutool.core.util.NumberUtil} 的功能
*
* @author 芋道源码
*/
public class NumberUtils {
public static Long parseLong(String str) {
return StrUtil.isNotEmpty(str) ? Long.valueOf(str) : null;
}
/**
* 通过经纬度获取地球上两点之间的距离
*
* 参考 <<a href="https://gitee.com/dromara/hutool/blob/1caabb586b1f95aec66a21d039c5695df5e0f4c1/hutool-core/src/main/java/cn/hutool/core/util/DistanceUtil.java">DistanceUtil</a>> 实现,目前它已经被 hutool 删除
*
* @param lat1 经度1
* @param lng1 纬度1
* @param lat2 经度2
* @param lng2 纬度2
* @return 距离,单位:千米
*/
public static double getDistance(double lat1, double lng1, double lat2, double lng2) {
double radLat1 = lat1 * Math.PI / 180.0;
double radLat2 = lat2 * Math.PI / 180.0;
double a = radLat1 - radLat2;
double b = lng1 * Math.PI / 180.0 - lng2 * Math.PI / 180.0;
double distance = 2 * Math.asin(Math.sqrt(Math.pow(Math.sin(a / 2), 2)
+ Math.cos(radLat1) * Math.cos(radLat2)
* Math.pow(Math.sin(b / 2), 2)));
distance = distance * 6378.137;
distance = Math.round(distance * 10000d) / 10000d;
return distance;
}
/**
* 提供精确的乘法运算
*
* 和 hutool {@link NumberUtil#mul(BigDecimal...)} 的差别是,如果存在 null则返回 null
*
* @param values 多个被乘值
* @return 积
*/
public static BigDecimal mul(BigDecimal... values) {
for (BigDecimal value : values) {
if (value == null) {
return null;
}
}
return NumberUtil.mul(values);
}
}

View File

@ -0,0 +1,62 @@
package cn.iocoder.yudao.framework.common.util.object;
import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import java.util.List;
import java.util.function.Consumer;
/**
* Bean 工具类
*
* 1. 默认使用 {@link cn.hutool.core.bean.BeanUtil} 作为实现类,虽然不同 bean 工具的性能有差别,但是对绝大多数同学的项目,不用在意这点性能
* 2. 针对复杂的对象转换,可以搜参考 AuthConvert 实现,通过 mapstruct + default 配合实现
*
* @author 芋道源码
*/
public class BeanUtils {
public static <T> T toBean(Object source, Class<T> targetClass) {
return BeanUtil.toBean(source, targetClass);
}
public static <T> T toBean(Object source, Class<T> targetClass, Consumer<T> peek) {
T target = toBean(source, targetClass);
if (target != null) {
peek.accept(target);
}
return target;
}
public static <S, T> List<T> toBean(List<S> source, Class<T> targetType) {
if (source == null) {
return null;
}
return CollectionUtils.convertList(source, s -> toBean(s, targetType));
}
public static <S, T> List<T> toBean(List<S> source, Class<T> targetType, Consumer<T> peek) {
List<T> list = toBean(source, targetType);
if (list != null) {
list.forEach(peek);
}
return list;
}
public static <S, T> PageResult<T> toBean(PageResult<S> source, Class<T> targetType) {
return toBean(source, targetType, null);
}
public static <S, T> PageResult<T> toBean(PageResult<S> source, Class<T> targetType, Consumer<T> peek) {
if (source == null) {
return null;
}
List<T> list = toBean(source.getList(), targetType);
if (peek != null) {
list.forEach(peek);
}
return new PageResult<>(list, source.getTotal());
}
}

View File

@ -0,0 +1,63 @@
package cn.iocoder.yudao.framework.common.util.object;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.ReflectUtil;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.function.Consumer;
/**
* Object 工具类
*
* @author 芋道源码
*/
public class ObjectUtils {
/**
* 复制对象,并忽略 Id 编号
*
* @param object 被复制对象
* @param consumer 消费者,可以二次编辑被复制对象
* @return 复制后的对象
*/
public static <T> T cloneIgnoreId(T object, Consumer<T> consumer) {
T result = ObjectUtil.clone(object);
// 忽略 id 编号
Field field = ReflectUtil.getField(object.getClass(), "id");
if (field != null) {
ReflectUtil.setFieldValue(result, field, null);
}
// 二次编辑
if (result != null) {
consumer.accept(result);
}
return result;
}
public static <T extends Comparable<T>> T max(T obj1, T obj2) {
if (obj1 == null) {
return obj2;
}
if (obj2 == null) {
return obj1;
}
return obj1.compareTo(obj2) > 0 ? obj1 : obj2;
}
@SafeVarargs
public static <T> T defaultIfNull(T... array) {
for (T item : array) {
if (item != null) {
return item;
}
}
return null;
}
@SafeVarargs
public static <T> boolean equalsAny(T obj, T... array) {
return Arrays.asList(array).contains(obj);
}
}

View File

@ -0,0 +1,67 @@
package cn.iocoder.yudao.framework.common.util.object;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.func.Func1;
import cn.hutool.core.lang.func.LambdaUtil;
import cn.hutool.core.util.ArrayUtil;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.SortablePageParam;
import cn.iocoder.yudao.framework.common.pojo.SortingField;
import org.springframework.util.Assert;
import static java.util.Collections.singletonList;
/**
* {@link cn.iocoder.yudao.framework.common.pojo.PageParam} 工具类
*
* @author 芋道源码
*/
public class PageUtils {
private static final Object[] ORDER_TYPES = new String[]{SortingField.ORDER_ASC, SortingField.ORDER_DESC};
public static int getStart(PageParam pageParam) {
return (pageParam.getPageNo() - 1) * pageParam.getPageSize();
}
/**
* 构建排序字段(默认倒序)
*
* @param func 排序字段的 Lambda 表达式
* @param <T> 排序字段所属的类型
* @return 排序字段
*/
public static <T> SortingField buildSortingField(Func1<T, ?> func) {
return buildSortingField(func, SortingField.ORDER_DESC);
}
/**
* 构建排序字段
*
* @param func 排序字段的 Lambda 表达式
* @param order 排序类型 {@link SortingField#ORDER_ASC} {@link SortingField#ORDER_DESC}
* @param <T> 排序字段所属的类型
* @return 排序字段
*/
public static <T> SortingField buildSortingField(Func1<T, ?> func, String order) {
Assert.isTrue(ArrayUtil.contains(ORDER_TYPES, order), String.format("字段的排序类型只能是 %s/%s", ORDER_TYPES));
String fieldName = LambdaUtil.getFieldName(func);
return new SortingField(fieldName, order);
}
/**
* 构建默认的排序字段
* 如果排序字段为空,则设置排序字段;否则忽略
*
* @param sortablePageParam 排序分页查询参数
* @param func 排序字段的 Lambda 表达式
* @param <T> 排序字段所属的类型
*/
public static <T> void buildDefaultSortingField(SortablePageParam sortablePageParam, Func1<T, ?> func) {
if (sortablePageParam != null && CollUtil.isEmpty(sortablePageParam.getSortingFields())) {
sortablePageParam.setSortingFields(singletonList(buildSortingField(func)));
}
}
}

View File

@ -0,0 +1,7 @@
/**
* 对于工具类的选择,优先查找 Hutool 中有没对应的方法
* 如果没有,则自己封装对应的工具类,以 Utils 结尾,用于区分
*
* ps如果担心 Hutool 存在坑的问题,可以阅读 Hutool 的实现源码,以确保可靠性。并且,可以补充相关的单元测试。
*/
package cn.iocoder.yudao.framework.common.util;

View File

@ -0,0 +1,111 @@
package cn.iocoder.yudao.framework.common.util.servlet;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.servlet.ServletUtil;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import org.springframework.http.MediaType;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.net.URLEncoder;
import java.util.Map;
/**
* 客户端工具类
*
* @author 芋道源码
*/
public class ServletUtils {
/**
* 返回 JSON 字符串
*
* @param response 响应
* @param object 对象,会序列化成 JSON 字符串
*/
@SuppressWarnings("deprecation") // 必须使用 APPLICATION_JSON_UTF8_VALUE否则会乱码
public static void writeJSON(HttpServletResponse response, Object object) {
String content = JsonUtils.toJsonString(object);
ServletUtil.write(response, content, MediaType.APPLICATION_JSON_UTF8_VALUE);
}
/**
* 返回附件
*
* @param response 响应
* @param filename 文件名
* @param content 附件内容
*/
public static void writeAttachment(HttpServletResponse response, String filename, byte[] content) throws IOException {
// 设置 header 和 contentType
response.setHeader("Content-Disposition", "attachment;filename=" + URLEncoder.encode(filename, "UTF-8"));
response.setContentType(MediaType.APPLICATION_OCTET_STREAM_VALUE);
// 输出附件
IoUtil.write(response.getOutputStream(), false, content);
}
/**
* @param request 请求
* @return ua
*/
public static String getUserAgent(HttpServletRequest request) {
String ua = request.getHeader("User-Agent");
return ua != null ? ua : "";
}
/**
* 获得请求
*
* @return HttpServletRequest
*/
public static HttpServletRequest getRequest() {
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
if (!(requestAttributes instanceof ServletRequestAttributes)) {
return null;
}
return ((ServletRequestAttributes) requestAttributes).getRequest();
}
public static String getUserAgent() {
HttpServletRequest request = getRequest();
if (request == null) {
return null;
}
return getUserAgent(request);
}
public static String getClientIP() {
HttpServletRequest request = getRequest();
if (request == null) {
return null;
}
return ServletUtil.getClientIP(request);
}
public static boolean isJsonRequest(ServletRequest request) {
return StrUtil.startWithIgnoreCase(request.getContentType(), MediaType.APPLICATION_JSON_VALUE);
}
public static String getBody(HttpServletRequest request) {
return ServletUtil.getBody(request);
}
public static byte[] getBodyBytes(HttpServletRequest request) {
return ServletUtil.getBodyBytes(request);
}
public static String getClientIP(HttpServletRequest request) {
return ServletUtil.getClientIP(request);
}
public static Map<String, String> getParamMap(HttpServletRequest request) {
return ServletUtil.getParamMap(request);
}
}

View File

@ -0,0 +1,46 @@
package cn.iocoder.yudao.framework.common.util.spring;
import cn.hutool.core.bean.BeanUtil;
import org.springframework.aop.framework.AdvisedSupport;
import org.springframework.aop.framework.AopProxy;
import org.springframework.aop.support.AopUtils;
/**
* Spring AOP 工具类
*
* 参考波克尔 http://www.bubuko.com/infodetail-3471885.html 实现
*/
public class SpringAopUtils {
/**
* 获取代理的目标对象
*
* @param proxy 代理对象
* @return 目标对象
*/
public static Object getTarget(Object proxy) throws Exception {
// 不是代理对象
if (!AopUtils.isAopProxy(proxy)) {
return proxy;
}
// Jdk 代理
if (AopUtils.isJdkDynamicProxy(proxy)) {
return getJdkDynamicProxyTargetObject(proxy);
}
// Cglib 代理
return getCglibProxyTargetObject(proxy);
}
private static Object getCglibProxyTargetObject(Object proxy) throws Exception {
Object dynamicAdvisedInterceptor = BeanUtil.getFieldValue(proxy, "CGLIB$CALLBACK_0");
AdvisedSupport advisedSupport = (AdvisedSupport) BeanUtil.getFieldValue(dynamicAdvisedInterceptor, "advised");
return advisedSupport.getTargetSource().getTarget();
}
private static Object getJdkDynamicProxyTargetObject(Object proxy) throws Exception {
AopProxy aopProxy = (AopProxy) BeanUtil.getFieldValue(proxy, "h");
AdvisedSupport advisedSupport = (AdvisedSupport) BeanUtil.getFieldValue(aopProxy, "advised");
return advisedSupport.getTargetSource().getTarget();
}
}

View File

@ -0,0 +1,89 @@
package cn.iocoder.yudao.framework.common.util.spring;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.ArrayUtil;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* Spring EL 表达式的工具类
*
* @author mashu
*/
public class SpringExpressionUtils {
/**
* Spring EL 表达式解析器
*/
private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();
/**
* 参数名发现器
*/
private static final ParameterNameDiscoverer PARAMETER_NAME_DISCOVERER = new DefaultParameterNameDiscoverer();
private SpringExpressionUtils() {
}
/**
* 从切面中,单个解析 EL 表达式的结果
*
* @param joinPoint 切面点
* @param expressionString EL 表达式数组
* @return 执行界面
*/
public static Object parseExpression(JoinPoint joinPoint, String expressionString) {
Map<String, Object> result = parseExpressions(joinPoint, Collections.singletonList(expressionString));
return result.get(expressionString);
}
/**
* 从切面中,批量解析 EL 表达式的结果
*
* @param joinPoint 切面点
* @param expressionStrings EL 表达式数组
* @return 结果key 为表达式value 为对应值
*/
public static Map<String, Object> parseExpressions(JoinPoint joinPoint, List<String> expressionStrings) {
// 如果为空,则不进行解析
if (CollUtil.isEmpty(expressionStrings)) {
return MapUtil.newHashMap();
}
// 第一步,构建解析的上下文 EvaluationContext
// 通过 joinPoint 获取被注解方法
MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
Method method = methodSignature.getMethod();
// 使用 spring 的 ParameterNameDiscoverer 获取方法形参名数组
String[] paramNames = PARAMETER_NAME_DISCOVERER.getParameterNames(method);
// Spring 的表达式上下文对象
EvaluationContext context = new StandardEvaluationContext();
// 给上下文赋值
if (ArrayUtil.isNotEmpty(paramNames)) {
Object[] args = joinPoint.getArgs();
for (int i = 0; i < paramNames.length; i++) {
context.setVariable(paramNames[i], args[i]);
}
}
// 第二步,逐个参数解析
Map<String, Object> result = MapUtil.newHashMap(expressionStrings.size(), true);
expressionStrings.forEach(key -> {
Object value = EXPRESSION_PARSER.parseExpression(key).getValue(context);
result.put(key, value);
});
return result;
}
}

View File

@ -0,0 +1,69 @@
package cn.iocoder.yudao.framework.common.util.string;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
/**
* 字符串工具类
*
* @author 芋道源码
*/
public class StrUtils {
public static String maxLength(CharSequence str, int maxLength) {
return StrUtil.maxLength(str, maxLength - 3); // -3 的原因,是该方法会补充 ... 恰好
}
/**
* 给定字符串是否以任何一个字符串开始
* 给定字符串和数组为空都返回 false
*
* @param str 给定字符串
* @param prefixes 需要检测的开始字符串
* @since 3.0.6
*/
public static boolean startWithAny(String str, Collection<String> prefixes) {
if (StrUtil.isEmpty(str) || ArrayUtil.isEmpty(prefixes)) {
return false;
}
for (CharSequence suffix : prefixes) {
if (StrUtil.startWith(str, suffix, false)) {
return true;
}
}
return false;
}
public static List<Long> splitToLong(String value, CharSequence separator) {
long[] longs = StrUtil.splitToLong(value, separator);
return Arrays.stream(longs).boxed().collect(Collectors.toList());
}
public static List<Integer> splitToInteger(String value, CharSequence separator) {
int[] integers = StrUtil.splitToInt(value, separator);
return Arrays.stream(integers).boxed().collect(Collectors.toList());
}
/**
* 移除字符串中,包含指定字符串的行
*
* @param content 字符串
* @param sequence 包含的字符串
* @return 移除后的字符串
*/
public static String removeLineContains(String content, String sequence) {
if (StrUtil.isEmpty(content) || StrUtil.isEmpty(sequence)) {
return content;
}
return Arrays.stream(content.split("\n"))
.filter(line -> !line.contains(sequence))
.collect(Collectors.joining("\n"));
}
}

View File

@ -0,0 +1,55 @@
package cn.iocoder.yudao.framework.common.util.validation;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import org.springframework.util.StringUtils;
import javax.validation.ConstraintViolation;
import javax.validation.ConstraintViolationException;
import javax.validation.Validation;
import javax.validation.Validator;
import java.util.Set;
import java.util.regex.Pattern;
/**
* 校验工具类
*
* @author 芋道源码
*/
public class ValidationUtils {
private static final Pattern PATTERN_MOBILE = Pattern.compile("^(?:(?:\\+|00)86)?1(?:(?:3[\\d])|(?:4[0,1,4-9])|(?:5[0-3,5-9])|(?:6[2,5-7])|(?:7[0-8])|(?:8[\\d])|(?:9[0-3,5-9]))\\d{8}$");
private static final Pattern PATTERN_URL = Pattern.compile("^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]");
private static final Pattern PATTERN_XML_NCNAME = Pattern.compile("[a-zA-Z_][\\-_.0-9_a-zA-Z$]*");
public static boolean isMobile(String mobile) {
return StringUtils.hasText(mobile)
&& PATTERN_MOBILE.matcher(mobile).matches();
}
public static boolean isURL(String url) {
return StringUtils.hasText(url)
&& PATTERN_URL.matcher(url).matches();
}
public static boolean isXmlNCName(String str) {
return StringUtils.hasText(str)
&& PATTERN_XML_NCNAME.matcher(str).matches();
}
public static void validate(Object object, Class<?>... groups) {
Validator validator = Validation.buildDefaultValidatorFactory().getValidator();
Assert.notNull(validator);
validate(validator, object, groups);
}
public static void validate(Validator validator, Object object, Class<?>... groups) {
Set<ConstraintViolation<Object>> constraintViolations = validator.validate(object, groups);
if (CollUtil.isNotEmpty(constraintViolations)) {
throw new ConstraintViolationException(constraintViolations);
}
}
}

View File

@ -0,0 +1,35 @@
package cn.iocoder.yudao.framework.common.validation;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import javax.validation.Constraint;
import javax.validation.Payload;
import java.lang.annotation.*;
@Target({
ElementType.METHOD,
ElementType.FIELD,
ElementType.ANNOTATION_TYPE,
ElementType.CONSTRUCTOR,
ElementType.PARAMETER,
ElementType.TYPE_USE
})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Constraint(
validatedBy = {InEnumValidator.class, InEnumCollectionValidator.class}
)
public @interface InEnum {
/**
* @return 实现 EnumValuable 接口的
*/
Class<? extends IntArrayValuable> value();
String message() default "必须在指定范围 {value}";
Class<?>[] groups() default {};
Class<? extends Payload>[] payload() default {};
}

View File

@ -0,0 +1,42 @@
package cn.iocoder.yudao.framework.common.validation;
import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
public class InEnumCollectionValidator implements ConstraintValidator<InEnum, Collection<Integer>> {
private List<Integer> values;
@Override
public void initialize(InEnum annotation) {
IntArrayValuable[] values = annotation.value().getEnumConstants();
if (values.length == 0) {
this.values = Collections.emptyList();
} else {
this.values = Arrays.stream(values[0].array()).boxed().collect(Collectors.toList());
}
}
@Override
public boolean isValid(Collection<Integer> list, ConstraintValidatorContext context) {
// 校验通过
if (CollUtil.containsAll(values, list)) {
return true;
}
// 校验不通过,自定义提示语句(因为,注解上的 value 是枚举类,无法获得枚举类的实际值)
context.disableDefaultConstraintViolation(); // 禁用默认的 message 的值
context.buildConstraintViolationWithTemplate(context.getDefaultConstraintMessageTemplate()
.replaceAll("\\{value}", CollUtil.join(list, ","))).addConstraintViolation(); // 重新添加错误提示语句
return false;
}
}

View File

@ -0,0 +1,44 @@
package cn.iocoder.yudao.framework.common.validation;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
public class InEnumValidator implements ConstraintValidator<InEnum, Integer> {
private List<Integer> values;
@Override
public void initialize(InEnum annotation) {
IntArrayValuable[] values = annotation.value().getEnumConstants();
if (values.length == 0) {
this.values = Collections.emptyList();
} else {
this.values = Arrays.stream(values[0].array()).boxed().collect(Collectors.toList());
}
}
@Override
public boolean isValid(Integer value, ConstraintValidatorContext context) {
// 为空时,默认不校验,即认为通过
if (value == null) {
return true;
}
// 校验通过
if (values.contains(value)) {
return true;
}
// 校验不通过,自定义提示语句(因为,注解上的 value 是枚举类,无法获得枚举类的实际值)
context.disableDefaultConstraintViolation(); // 禁用默认的 message 的值
context.buildConstraintViolationWithTemplate(context.getDefaultConstraintMessageTemplate()
.replaceAll("\\{value}", values.toString())).addConstraintViolation(); // 重新添加错误提示语句
return false;
}
}

View File

@ -0,0 +1,28 @@
package cn.iocoder.yudao.framework.common.validation;
import javax.validation.Constraint;
import javax.validation.Payload;
import java.lang.annotation.*;
@Target({
ElementType.METHOD,
ElementType.FIELD,
ElementType.ANNOTATION_TYPE,
ElementType.CONSTRUCTOR,
ElementType.PARAMETER,
ElementType.TYPE_USE
})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Constraint(
validatedBy = MobileValidator.class
)
public @interface Mobile {
String message() default "手机号格式不正确";
Class<?>[] groups() default {};
Class<? extends Payload>[] payload() default {};
}

View File

@ -0,0 +1,25 @@
package cn.iocoder.yudao.framework.common.validation;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.util.validation.ValidationUtils;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
public class MobileValidator implements ConstraintValidator<Mobile, String> {
@Override
public void initialize(Mobile annotation) {
}
@Override
public boolean isValid(String value, ConstraintValidatorContext context) {
// 如果手机号为空,默认不校验,即校验通过
if (StrUtil.isEmpty(value)) {
return true;
}
// 校验手机
return ValidationUtils.isMobile(value);
}
}

View File

@ -0,0 +1,28 @@
package cn.iocoder.yudao.framework.common.validation;
import javax.validation.Constraint;
import javax.validation.Payload;
import java.lang.annotation.*;
@Target({
ElementType.METHOD,
ElementType.FIELD,
ElementType.ANNOTATION_TYPE,
ElementType.CONSTRUCTOR,
ElementType.PARAMETER,
ElementType.TYPE_USE
})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Constraint(
validatedBy = TelephoneValidator.class
)
public @interface Telephone {
String message() default "电话格式不正确";
Class<?>[] groups() default {};
Class<? extends Payload>[] payload() default {};
}

View File

@ -0,0 +1,25 @@
package cn.iocoder.yudao.framework.common.validation;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.PhoneUtil;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
public class TelephoneValidator implements ConstraintValidator<Telephone, String> {
@Override
public void initialize(Telephone annotation) {
}
@Override
public boolean isValid(String value, ConstraintValidatorContext context) {
// 如果手机号为空,默认不校验,即校验通过
if (CharSequenceUtil.isEmpty(value)) {
return true;
}
// 校验手机
return PhoneUtil.isTel(value) || PhoneUtil.isPhone(value);
}
}

View File

@ -0,0 +1,4 @@
/**
* 使用 Hibernate Validator 实现参数校验
*/
package cn.iocoder.yudao.framework.common.validation;

View File

@ -0,0 +1,64 @@
package cn.iocoder.yudao.framework.common.util.collection;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.function.BiFunction;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* {@link CollectionUtils} 的单元测试
*/
public class CollectionUtilsTest {
@Data
@AllArgsConstructor
private static class Dog {
private Integer id;
private String name;
private String code;
}
@Test
public void testDiffList() {
// 准备参数
Collection<Dog> oldList = Arrays.asList(
new Dog(1, "花花", "hh"),
new Dog(2, "旺财", "wc")
);
Collection<Dog> newList = Arrays.asList(
new Dog(null, "花花2", "hh"),
new Dog(null, "小白", "xb")
);
BiFunction<Dog, Dog, Boolean> sameFunc = (oldObj, newObj) -> {
boolean same = oldObj.getCode().equals(newObj.getCode());
// 如果相等的情况下,需要设置下 id后续好更新
if (same) {
newObj.setId(oldObj.getId());
}
return same;
};
// 调用
List<List<Dog>> result = CollectionUtils.diffList(oldList, newList, sameFunc);
// 断言
assertEquals(result.size(), 3);
// 断言 create
assertEquals(result.get(0).size(), 1);
assertEquals(result.get(0).get(0), new Dog(null, "小白", "xb"));
// 断言 update
assertEquals(result.get(1).size(), 1);
assertEquals(result.get(1).get(0), new Dog(1, "花花2", "hh"));
// 断言 delete
assertEquals(result.get(2).size(), 1);
assertEquals(result.get(2).get(0), new Dog(2, "旺财", "wc"));
}
}

View File

@ -0,0 +1 @@
<http://www.iocoder.cn/Spring-Boot/Validation/?yudao>

View File

@ -0,0 +1,52 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>yudao-framework</artifactId>
<groupId>cn.iocoder.boot</groupId>
<version>${revision}</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>yudao-spring-boot-starter-biz-data-permission</artifactId>
<packaging>jar</packaging>
<name>${project.artifactId}</name>
<description>数据权限</description>
<url>https://github.com/YunaiV/ruoyi-vue-pro</url>
<dependencies>
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-common</artifactId>
</dependency>
<!-- Web 相关 -->
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-spring-boot-starter-security</artifactId>
<optional>true</optional> <!-- 可选,如果使用 DeptDataPermissionRule 必须提供 -->
</dependency>
<!-- DB 相关 -->
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-spring-boot-starter-mybatis</artifactId>
</dependency>
<!-- 业务组件 -->
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-module-system-api</artifactId> <!-- 需要使用它,进行数据权限的获取 -->
<version>${revision}</version>
</dependency>
<!-- Test 测试相关 -->
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,44 @@
package cn.iocoder.yudao.framework.datapermission.config;
import cn.iocoder.yudao.framework.datapermission.core.aop.DataPermissionAnnotationAdvisor;
import cn.iocoder.yudao.framework.datapermission.core.db.DataPermissionDatabaseInterceptor;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactoryImpl;
import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.context.annotation.Bean;
import java.util.List;
/**
* 数据权限的自动配置类
*
* @author 芋道源码
*/
@AutoConfiguration
public class YudaoDataPermissionAutoConfiguration {
@Bean
public DataPermissionRuleFactory dataPermissionRuleFactory(List<DataPermissionRule> rules) {
return new DataPermissionRuleFactoryImpl(rules);
}
@Bean
public DataPermissionDatabaseInterceptor dataPermissionDatabaseInterceptor(MybatisPlusInterceptor interceptor,
DataPermissionRuleFactory ruleFactory) {
// 创建 DataPermissionDatabaseInterceptor 拦截器
DataPermissionDatabaseInterceptor inner = new DataPermissionDatabaseInterceptor(ruleFactory);
// 添加到 interceptor 中
// 需要加在首个,主要是为了在分页插件前面。这个是 MyBatis Plus 的规定
MyBatisUtils.addInterceptor(interceptor, inner, 0);
return inner;
}
@Bean
public DataPermissionAnnotationAdvisor dataPermissionAnnotationAdvisor() {
return new DataPermissionAnnotationAdvisor();
}
}

View File

@ -0,0 +1,34 @@
package cn.iocoder.yudao.framework.datapermission.config;
import cn.iocoder.yudao.framework.datapermission.core.rule.dept.DeptDataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.dept.DeptDataPermissionRuleCustomizer;
import cn.iocoder.yudao.framework.security.core.LoginUser;
import cn.iocoder.yudao.module.system.api.permission.PermissionApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.context.annotation.Bean;
import java.util.List;
/**
* 基于部门的数据权限 AutoConfiguration
*
* @author 芋道源码
*/
@AutoConfiguration
@ConditionalOnClass(LoginUser.class)
@ConditionalOnBean(value = {PermissionApi.class, DeptDataPermissionRuleCustomizer.class})
public class YudaoDeptDataPermissionAutoConfiguration {
@Bean
public DeptDataPermissionRule deptDataPermissionRule(PermissionApi permissionApi,
List<DeptDataPermissionRuleCustomizer> customizers) {
// 创建 DeptDataPermissionRule 对象
DeptDataPermissionRule rule = new DeptDataPermissionRule(permissionApi);
// 补全表配置
customizers.forEach(customizer -> customizer.customize(rule));
return rule;
}
}

View File

@ -0,0 +1,35 @@
package cn.iocoder.yudao.framework.datapermission.core.annotation;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import java.lang.annotation.*;
/**
* 数据权限注解
* 可声明在类或者方法上,标识使用的数据权限规则
*
* @author 芋道源码
*/
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DataPermission {
/**
* 当前类或方法是否开启数据权限
* 即使不添加 @DataPermission 注解,默认是开启状态
* 可通过设置 enable 为 false 禁用
*/
boolean enable() default true;
/**
* 生效的数据权限规则数组,优先级高于 {@link #excludeRules()}
*/
Class<? extends DataPermissionRule>[] includeRules() default {};
/**
* 排除的数据权限规则数组,优先级最低
*/
Class<? extends DataPermissionRule>[] excludeRules() default {};
}

View File

@ -0,0 +1,36 @@
package cn.iocoder.yudao.framework.datapermission.core.aop;
import cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.aopalliance.aop.Advice;
import org.springframework.aop.Pointcut;
import org.springframework.aop.support.AbstractPointcutAdvisor;
import org.springframework.aop.support.ComposablePointcut;
import org.springframework.aop.support.annotation.AnnotationMatchingPointcut;
/**
* {@link cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission} 注解的 Advisor 实现类
*
* @author 芋道源码
*/
@Getter
@EqualsAndHashCode(callSuper = true)
public class DataPermissionAnnotationAdvisor extends AbstractPointcutAdvisor {
private final Advice advice;
private final Pointcut pointcut;
public DataPermissionAnnotationAdvisor() {
this.advice = new DataPermissionAnnotationInterceptor();
this.pointcut = this.buildPointcut();
}
protected Pointcut buildPointcut() {
Pointcut classPointcut = new AnnotationMatchingPointcut(DataPermission.class, true);
Pointcut methodPointcut = new AnnotationMatchingPointcut(null, DataPermission.class, true);
return new ComposablePointcut(classPointcut).union(methodPointcut);
}
}

View File

@ -0,0 +1,72 @@
package cn.iocoder.yudao.framework.datapermission.core.aop;
import cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission;
import lombok.Getter;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.springframework.core.MethodClassKey;
import org.springframework.core.annotation.AnnotationUtils;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* {@link DataPermission} 注解的拦截器
* 1. 在执行方法前,将 @DataPermission 注解入栈
* 2. 在执行方法后,将 @DataPermission 注解出栈
*
* @author 芋道源码
*/
@DataPermission // 该注解,用于 {@link DATA_PERMISSION_NULL} 的空对象
public class DataPermissionAnnotationInterceptor implements MethodInterceptor {
/**
* DataPermission 空对象,用于方法无 {@link DataPermission} 注解时,使用 DATA_PERMISSION_NULL 进行占位
*/
static final DataPermission DATA_PERMISSION_NULL = DataPermissionAnnotationInterceptor.class.getAnnotation(DataPermission.class);
@Getter
private final Map<MethodClassKey, DataPermission> dataPermissionCache = new ConcurrentHashMap<>();
@Override
public Object invoke(MethodInvocation methodInvocation) throws Throwable {
// 入栈
DataPermission dataPermission = this.findAnnotation(methodInvocation);
if (dataPermission != null) {
DataPermissionContextHolder.add(dataPermission);
}
try {
// 执行逻辑
return methodInvocation.proceed();
} finally {
// 出栈
if (dataPermission != null) {
DataPermissionContextHolder.remove();
}
}
}
private DataPermission findAnnotation(MethodInvocation methodInvocation) {
// 1. 从缓存中获取
Method method = methodInvocation.getMethod();
Object targetObject = methodInvocation.getThis();
Class<?> clazz = targetObject != null ? targetObject.getClass() : method.getDeclaringClass();
MethodClassKey methodClassKey = new MethodClassKey(method, clazz);
DataPermission dataPermission = dataPermissionCache.get(methodClassKey);
if (dataPermission != null) {
return dataPermission != DATA_PERMISSION_NULL ? dataPermission : null;
}
// 2.1 从方法中获取
dataPermission = AnnotationUtils.findAnnotation(method, DataPermission.class);
// 2.2 从类上获取
if (dataPermission == null) {
dataPermission = AnnotationUtils.findAnnotation(clazz, DataPermission.class);
}
// 2.3 添加到缓存中
dataPermissionCache.put(methodClassKey, dataPermission != null ? dataPermission : DATA_PERMISSION_NULL);
return dataPermission;
}
}

View File

@ -0,0 +1,72 @@
package cn.iocoder.yudao.framework.datapermission.core.aop;
import cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission;
import com.alibaba.ttl.TransmittableThreadLocal;
import java.util.LinkedList;
import java.util.List;
/**
* {@link DataPermission} 注解的 Context 上下文
*
* @author 芋道源码
*/
public class DataPermissionContextHolder {
/**
* 使用 List 的原因,可能存在方法的嵌套调用
*/
private static final ThreadLocal<LinkedList<DataPermission>> DATA_PERMISSIONS =
TransmittableThreadLocal.withInitial(LinkedList::new);
/**
* 获得当前的 DataPermission 注解
*
* @return DataPermission 注解
*/
public static DataPermission get() {
return DATA_PERMISSIONS.get().peekLast();
}
/**
* 入栈 DataPermission 注解
*
* @param dataPermission DataPermission 注解
*/
public static void add(DataPermission dataPermission) {
DATA_PERMISSIONS.get().addLast(dataPermission);
}
/**
* 出栈 DataPermission 注解
*
* @return DataPermission 注解
*/
public static DataPermission remove() {
DataPermission dataPermission = DATA_PERMISSIONS.get().removeLast();
// 无元素时,清空 ThreadLocal
if (DATA_PERMISSIONS.get().isEmpty()) {
DATA_PERMISSIONS.remove();
}
return dataPermission;
}
/**
* 获得所有 DataPermission
*
* @return DataPermission 队列
*/
public static List<DataPermission> getAll() {
return DATA_PERMISSIONS.get();
}
/**
* 清空上下文
*
* 目前仅仅用于单测
*/
public static void clear() {
DATA_PERMISSIONS.remove();
}
}

View File

@ -0,0 +1,641 @@
package cn.iocoder.yudao.framework.datapermission.core.db;
import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.sql.Connection;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
* 数据权限拦截器,通过 {@link DataPermissionRule} 数据权限规则,重写 SQL 的方式来实现
* 主要的 SQL 重写方法,可见 {@link #builderExpression(Expression, List)} 方法
*
* 整体的代码实现上,参考 {@link com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor} 实现。
* 所以每次 MyBatis Plus 升级时,需要 Review 下其具体的实现是否有变更!
*
* @author 芋道源码
*/
@RequiredArgsConstructor
public class DataPermissionDatabaseInterceptor extends JsqlParserSupport implements InnerInterceptor {
private final DataPermissionRuleFactory ruleFactory;
@Getter
private final MappedStatementCache mappedStatementCache = new MappedStatementCache();
@Override // SELECT 场景
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
// 获得 Mapper 对应的数据权限的规则
List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
return;
}
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
try {
// 初始化上下文
ContextHolder.init(rules);
// 处理 SQL
mpBs.sql(parserSingle(mpBs.sql(), null));
} finally {
// 添加是否需要重写的缓存
addMappedStatementCache(ms);
// 清空上下文
ContextHolder.clear();
}
}
@Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景(因为 INSERT 不需要数据权限)
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpSh.mappedStatement();
SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
// 获得 Mapper 对应的数据权限的规则
List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
return;
}
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
try {
// 初始化上下文
ContextHolder.init(rules);
// 处理 SQL
mpBs.sql(parserMulti(mpBs.sql(), null));
} finally {
// 添加是否需要重写的缓存
addMappedStatementCache(ms);
// 清空上下文
ContextHolder.clear();
}
}
}
@Override
protected void processSelect(Select select, int index, String sql, Object obj) {
processSelectBody(select.getSelectBody());
List<WithItem> withItemsList = select.getWithItemsList();
if (!CollectionUtils.isEmpty(withItemsList)) {
withItemsList.forEach(this::processSelectBody);
}
}
/**
* update 语句处理
*/
@Override
protected void processUpdate(Update update, int index, String sql, Object obj) {
final Table table = update.getTable();
update.setWhere(this.builderExpression(update.getWhere(), table));
}
/**
* delete 语句处理
*/
@Override
protected void processDelete(Delete delete, int index, String sql, Object obj) {
delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable()));
}
// ========== 和 TenantLineInnerInterceptor 一致的逻辑 ==========
protected void processSelectBody(SelectBody selectBody) {
if (selectBody == null) {
return;
}
if (selectBody instanceof PlainSelect) {
processPlainSelect((PlainSelect) selectBody);
} else if (selectBody instanceof WithItem) {
WithItem withItem = (WithItem) selectBody;
processSelectBody(withItem.getSubSelect().getSelectBody());
} else {
SetOperationList operationList = (SetOperationList) selectBody;
List<SelectBody> selectBodyList = operationList.getSelects();
if (CollectionUtils.isNotEmpty(selectBodyList)) {
selectBodyList.forEach(this::processSelectBody);
}
}
}
/**
* 处理 PlainSelect
*/
protected void processPlainSelect(PlainSelect plainSelect) {
//#3087 github
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (CollectionUtils.isNotEmpty(selectItems)) {
selectItems.forEach(this::processSelectItem);
}
// 处理 where 中的子查询
Expression where = plainSelect.getWhere();
processWhereSubSelect(where);
// 处理 fromItem
FromItem fromItem = plainSelect.getFromItem();
List<Table> list = processFromItem(fromItem);
List<Table> mainTables = new ArrayList<>(list);
// 处理 join
List<Join> joins = plainSelect.getJoins();
if (CollectionUtils.isNotEmpty(joins)) {
mainTables = processJoins(mainTables, joins);
}
// 当有 mainTable 时,进行 where 条件追加
if (CollectionUtils.isNotEmpty(mainTables)) {
plainSelect.setWhere(builderExpression(where, mainTables));
}
}
private List<Table> processFromItem(FromItem fromItem) {
// 处理括号括起来的表达式
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}
List<Table> mainTables = new ArrayList<>();
// 无 join 时的处理逻辑
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
mainTables.add(fromTable);
} else if (fromItem instanceof SubJoin) {
// SubJoin 类型则还需要添加上 where 条件
List<Table> tables = processSubJoin((SubJoin) fromItem);
mainTables.addAll(tables);
} else {
// 处理下 fromItem
processOtherFromItem(fromItem);
}
return mainTables;
}
/**
* 处理where条件内的子查询
* <p>
* 支持如下:
* 1. in
* 2. =
* 3. >
* 4. <
* 5. >=
* 6. <=
* 7. <>
* 8. EXISTS
* 9. NOT EXISTS
* <p>
* 前提条件:
* 1. 子查询必须放在小括号中
* 2. 子查询一般放在比较操作符的右边
*
* @param where where 条件
*/
protected void processWhereSubSelect(Expression where) {
if (where == null) {
return;
}
if (where instanceof FromItem) {
processOtherFromItem((FromItem) where);
return;
}
if (where.toString().indexOf("SELECT") > 0) {
// 有子查询
if (where instanceof BinaryExpression) {
// 比较符号 , and , or , 等等
BinaryExpression expression = (BinaryExpression) where;
processWhereSubSelect(expression.getLeftExpression());
processWhereSubSelect(expression.getRightExpression());
} else if (where instanceof InExpression) {
// in
InExpression expression = (InExpression) where;
Expression inExpression = expression.getRightExpression();
if (inExpression instanceof SubSelect) {
processSelectBody(((SubSelect) inExpression).getSelectBody());
}
} else if (where instanceof ExistsExpression) {
// exists
ExistsExpression expression = (ExistsExpression) where;
processWhereSubSelect(expression.getRightExpression());
} else if (where instanceof NotExpression) {
// not exists
NotExpression expression = (NotExpression) where;
processWhereSubSelect(expression.getExpression());
} else if (where instanceof Parenthesis) {
Parenthesis expression = (Parenthesis) where;
processWhereSubSelect(expression.getExpression());
}
}
}
protected void processSelectItem(SelectItem selectItem) {
if (selectItem instanceof SelectExpressionItem) {
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
if (selectExpressionItem.getExpression() instanceof SubSelect) {
processSelectBody(((SubSelect) selectExpressionItem.getExpression()).getSelectBody());
} else if (selectExpressionItem.getExpression() instanceof Function) {
processFunction((Function) selectExpressionItem.getExpression());
}
}
}
/**
* 处理函数
* <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
* <p> fixed gitee pulls/141</p>
*
* @param function
*/
protected void processFunction(Function function) {
ExpressionList parameters = function.getParameters();
if (parameters != null) {
parameters.getExpressions().forEach(expression -> {
if (expression instanceof SubSelect) {
processSelectBody(((SubSelect) expression).getSelectBody());
} else if (expression instanceof Function) {
processFunction((Function) expression);
}
});
}
}
/**
* 处理子查询等
*/
protected void processOtherFromItem(FromItem fromItem) {
// 去除括号
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}
if (fromItem instanceof SubSelect) {
SubSelect subSelect = (SubSelect) fromItem;
if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody());
}
} else if (fromItem instanceof ValuesList) {
logger.debug("Perform a subQuery, if you do not give us feedback");
} else if (fromItem instanceof LateralSubSelect) {
LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
if (lateralSubSelect.getSubSelect() != null) {
SubSelect subSelect = lateralSubSelect.getSubSelect();
if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody());
}
}
}
}
/**
* 处理 sub join
*
* @param subJoin subJoin
* @return Table subJoin 中的主表
*/
private List<Table> processSubJoin(SubJoin subJoin) {
List<Table> mainTables = new ArrayList<>();
if (subJoin.getJoinList() != null) {
List<Table> list = processFromItem(subJoin.getLeft());
mainTables.addAll(list);
mainTables = processJoins(mainTables, subJoin.getJoinList());
}
return mainTables;
}
/**
* 处理 joins
*
* @param mainTables 可以为 null
* @param joins join 集合
* @return List<Table> 右连接查询的 Table 列表
*/
private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
// join 表达式中最终的主表
Table mainTable = null;
// 当前 join 的左表
Table leftTable = null;
if (mainTables == null) {
mainTables = new ArrayList<>();
} else if (mainTables.size() == 1) {
mainTable = mainTables.get(0);
leftTable = mainTable;
}
//对于 on 表达式写在最后的 join需要记录下前面多个 on 的表名
Deque<List<Table>> onTableDeque = new LinkedList<>();
for (Join join : joins) {
// 处理 on 表达式
FromItem joinItem = join.getRightItem();
// 获取当前 join 的表subJoint 可以看作是一张表
List<Table> joinTables = null;
if (joinItem instanceof Table) {
joinTables = new ArrayList<>();
joinTables.add((Table) joinItem);
} else if (joinItem instanceof SubJoin) {
joinTables = processSubJoin((SubJoin) joinItem);
}
if (joinTables != null) {
// 如果是隐式内连接
if (join.isSimple()) {
mainTables.addAll(joinTables);
continue;
}
// 当前表是否忽略
Table joinTable = joinTables.get(0);
List<Table> onTables = null;
// 如果不要忽略,且是右连接,则记录下当前表
if (join.isRight()) {
mainTable = joinTable;
if (leftTable != null) {
onTables = Collections.singletonList(leftTable);
}
} else if (join.isLeft()) {
onTables = Collections.singletonList(joinTable);
} else if (join.isInner()) {
if (mainTable == null) {
onTables = Collections.singletonList(joinTable);
} else {
onTables = Arrays.asList(mainTable, joinTable);
}
mainTable = null;
}
mainTables = new ArrayList<>();
if (mainTable != null) {
mainTables.add(mainTable);
}
// 获取 join 尾缀的 on 表达式列表
Collection<Expression> originOnExpressions = join.getOnExpressions();
// 正常 join on 表达式只有一个,立刻处理
if (originOnExpressions.size() == 1 && onTables != null) {
List<Expression> onExpressions = new LinkedList<>();
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
join.setOnExpressions(onExpressions);
leftTable = joinTable;
continue;
}
// 表名压栈,忽略的表压入 null以便后续不处理
onTableDeque.push(onTables);
// 尾缀多个 on 表达式的时候统一处理
if (originOnExpressions.size() > 1) {
Collection<Expression> onExpressions = new LinkedList<>();
for (Expression originOnExpression : originOnExpressions) {
List<Table> currentTableList = onTableDeque.poll();
if (CollectionUtils.isEmpty(currentTableList)) {
onExpressions.add(originOnExpression);
} else {
onExpressions.add(builderExpression(originOnExpression, currentTableList));
}
}
join.setOnExpressions(onExpressions);
}
leftTable = joinTable;
} else {
processOtherFromItem(joinItem);
leftTable = null;
}
}
return mainTables;
}
// ========== 和 TenantLineInnerInterceptor 存在差异的逻辑:关键,实现权限条件的拼接 ==========
/**
* 处理条件
*
* @param currentExpression 当前 where 条件
* @param table 单个表
*/
protected Expression builderExpression(Expression currentExpression, Table table) {
return this.builderExpression(currentExpression, Collections.singletonList(table));
}
/**
* 处理条件
*
* @param currentExpression 当前 where 条件
* @param tables 多个表
*/
protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
// 没有表需要处理直接返回
if (CollectionUtils.isEmpty(tables)) {
return currentExpression;
}
// 第一步,获得 Table 对应的数据权限条件
Expression dataPermissionExpression = null;
for (Table table : tables) {
// 构建每个表的权限 Expression 条件
Expression expression = buildDataPermissionExpression(table);
if (expression == null) {
continue;
}
// 合并到 dataPermissionExpression 中
dataPermissionExpression = dataPermissionExpression == null ? expression
: new AndExpression(dataPermissionExpression, expression);
}
// 第二步,合并多个 Expression 条件
if (dataPermissionExpression == null) {
return currentExpression;
}
if (currentExpression == null) {
return dataPermissionExpression;
}
// ① 如果表达式为 Or则需要 (currentExpression) AND dataPermissionExpression
if (currentExpression instanceof OrExpression) {
return new AndExpression(new Parenthesis(currentExpression), dataPermissionExpression);
}
// ② 如果表达式为 And则直接返回 where AND dataPermissionExpression
return new AndExpression(currentExpression, dataPermissionExpression);
}
/**
* 构建指定表的数据权限的 Expression 过滤条件
*
* @param table 表
* @return Expression 过滤条件
*/
private Expression buildDataPermissionExpression(Table table) {
// 生成条件
Expression allExpression = null;
for (DataPermissionRule rule : ContextHolder.getRules()) {
// 判断表名是否匹配
String tableName = MyBatisUtils.getTableName(table);
if (!rule.getTableNames().contains(tableName)) {
continue;
}
// 如果有匹配的规则,说明可重写。
// 为什么不是有 allExpression 非空才重写呢?在生成 column = value 过滤条件时,会因为 value 不存在,导致未重写。
// 这样导致第一次无 value被标记成无需重写但是第二次有 value此时会需要重写。
ContextHolder.setRewrite(true);
// 单条规则的条件
Expression oneExpress = rule.getExpression(tableName, table.getAlias());
if (oneExpress == null){
continue;
}
// 拼接到 allExpression 中
allExpression = allExpression == null ? oneExpress
: new AndExpression(allExpression, oneExpress);
}
return allExpression;
}
/**
* 判断 SQL 是否重写。如果没有重写,则添加到 {@link MappedStatementCache} 中
*
* @param ms MappedStatement
*/
private void addMappedStatementCache(MappedStatement ms) {
if (ContextHolder.getRewrite()) {
return;
}
// 无重写,进行添加
mappedStatementCache.addNoRewritable(ms, ContextHolder.getRules());
}
/**
* SQL 解析上下文,方便透传 {@link DataPermissionRule} 规则
*
* @author 芋道源码
*/
static final class ContextHolder {
/**
* 该 {@link MappedStatement} 对应的规则
*/
private static final ThreadLocal<List<DataPermissionRule>> RULES = ThreadLocal.withInitial(Collections::emptyList);
/**
* SQL 是否进行重写
*/
private static final ThreadLocal<Boolean> REWRITE = ThreadLocal.withInitial(() -> Boolean.FALSE);
public static void init(List<DataPermissionRule> rules) {
RULES.set(rules);
REWRITE.set(false);
}
public static void clear() {
RULES.remove();
REWRITE.remove();
}
public static boolean getRewrite() {
return REWRITE.get();
}
public static void setRewrite(boolean rewrite) {
REWRITE.set(rewrite);
}
public static List<DataPermissionRule> getRules() {
return RULES.get();
}
}
/**
* {@link MappedStatement} 缓存
* 目前主要用于,记录 {@link DataPermissionRule} 是否对指定 {@link MappedStatement} 无效
* 如果无效,则可以避免 SQL 的解析,加快速度
*
* @author 芋道源码
*/
static final class MappedStatementCache {
/**
* 指定数据权限规则,对指定 MappedStatement 无需重写(不生效)的缓存
*
* value{@link MappedStatement#getId()} 编号
*/
@Getter
private final Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements = new ConcurrentHashMap<>();
/**
* 判断是否无需重写
* ps虽然有点中文式英语但是容易读懂即可
*
* @param ms MappedStatement
* @param rules 数据权限规则数组
* @return 是否无需重写
*/
public boolean noRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
// 如果规则为空,说明无需重写
if (CollUtil.isEmpty(rules)) {
return true;
}
// 任一规则不在 noRewritableMap 中,则说明可能需要重写
for (DataPermissionRule rule : rules) {
Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
if (!CollUtil.contains(mappedStatementIds, ms.getId())) {
return false;
}
}
return true;
}
/**
* 添加无需重写的 MappedStatement
*
* @param ms MappedStatement
* @param rules 数据权限规则数组
*/
public void addNoRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
for (DataPermissionRule rule : rules) {
Set<String> mappedStatementIds = noRewritableMappedStatements.get(rule.getClass());
if (CollUtil.isNotEmpty(mappedStatementIds)) {
mappedStatementIds.add(ms.getId());
} else {
noRewritableMappedStatements.put(rule.getClass(), SetUtils.asSet(ms.getId()));
}
}
}
/**
* 清空缓存
* 目前主要提供给单元测试
*/
public void clear() {
noRewritableMappedStatements.clear();
}
}
}

View File

@ -0,0 +1,36 @@
package cn.iocoder.yudao.framework.datapermission.core.rule;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import java.util.Set;
/**
* 数据权限规则接口
* 通过实现接口,自定义数据规则。例如说,
*
* @author 芋道源码
*/
public interface DataPermissionRule {
/**
* 返回需要生效的表名数组
* 为什么需要该方法Data Permission 数组基于 SQL 重写,通过 Where 返回只有权限的数据
*
* 如果需要基于实体名获得表名,可调用 {@link TableInfoHelper#getTableInfo(Class)} 获得
*
* @return 表名数组
*/
Set<String> getTableNames();
/**
* 根据表名和别名,生成对应的 WHERE / OR 过滤条件
*
* @param tableName 表名
* @param tableAlias 别名,可能为空
* @return 过滤条件 Expression 表达式
*/
Expression getExpression(String tableName, Alias tableAlias);
}

View File

@ -0,0 +1,28 @@
package cn.iocoder.yudao.framework.datapermission.core.rule;
import java.util.List;
/**
* {@link DataPermissionRule} 工厂接口
* 作为 {@link DataPermissionRule} 的容器,提供管理能力
*
* @author 芋道源码
*/
public interface DataPermissionRuleFactory {
/**
* 获得所有数据权限规则数组
*
* @return 数据权限规则数组
*/
List<DataPermissionRule> getDataPermissionRules();
/**
* 获得指定 Mapper 的数据权限规则数组
*
* @param mappedStatementId 指定 Mapper 的编号
* @return 数据权限规则数组
*/
List<DataPermissionRule> getDataPermissionRule(String mappedStatementId);
}

View File

@ -0,0 +1,62 @@
package cn.iocoder.yudao.framework.datapermission.core.rule;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ArrayUtil;
import cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission;
import cn.iocoder.yudao.framework.datapermission.core.aop.DataPermissionContextHolder;
import lombok.RequiredArgsConstructor;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
/**
* 默认的 DataPermissionRuleFactoryImpl 实现类
* 支持通过 {@link DataPermissionContextHolder} 过滤数据权限
*
* @author 芋道源码
*/
@RequiredArgsConstructor
public class DataPermissionRuleFactoryImpl implements DataPermissionRuleFactory {
/**
* 数据权限规则数组
*/
private final List<DataPermissionRule> rules;
@Override
public List<DataPermissionRule> getDataPermissionRules() {
return rules;
}
@Override // mappedStatementId 参数,暂时没有用。以后,可以基于 mappedStatementId + DataPermission 进行缓存
public List<DataPermissionRule> getDataPermissionRule(String mappedStatementId) {
// 1. 无数据权限
if (CollUtil.isEmpty(rules)) {
return Collections.emptyList();
}
// 2. 未配置,则默认开启
DataPermission dataPermission = DataPermissionContextHolder.get();
if (dataPermission == null) {
return rules;
}
// 3. 已配置,但禁用
if (!dataPermission.enable()) {
return Collections.emptyList();
}
// 4. 已配置,只选择部分规则
if (ArrayUtil.isNotEmpty(dataPermission.includeRules())) {
return rules.stream().filter(rule -> ArrayUtil.contains(dataPermission.includeRules(), rule.getClass()))
.collect(Collectors.toList()); // 一般规则不会太多,所以不采用 HashSet 查询
}
// 5. 已配置,只排除部分规则
if (ArrayUtil.isNotEmpty(dataPermission.excludeRules())) {
return rules.stream().filter(rule -> !ArrayUtil.contains(dataPermission.excludeRules(), rule.getClass()))
.collect(Collectors.toList()); // 一般规则不会太多,所以不采用 HashSet 查询
}
// 6. 已配置,全部规则
return rules;
}
}

View File

@ -0,0 +1,205 @@
package cn.iocoder.yudao.framework.datapermission.core.rule.dept;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.enums.UserTypeEnum;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
import cn.iocoder.yudao.framework.security.core.LoginUser;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.system.api.permission.PermissionApi;
import cn.iocoder.yudao.module.system.api.permission.dto.DeptDataPermissionRespDTO;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
* 基于部门的 {@link DataPermissionRule} 数据权限规则实现
*
* 注意,使用 DeptDataPermissionRule 时,需要保证表中有 dept_id 部门编号的字段,可自定义。
*
* 实际业务场景下,会存在一个经典的问题?当用户修改部门时,冗余的 dept_id 是否需要修改?
* 1. 一般情况下dept_id 不进行修改则会导致用户看不到之前的数据。【yudao-server 采用该方案】
* 2. 部分情况下,希望该用户还是能看到之前的数据,则有两种方式解决:【需要你改造该 DeptDataPermissionRule 的实现代码】
* 1编写洗数据的脚本将 dept_id 修改成新部门的编号;【建议】
* 最终过滤条件是 WHERE dept_id = ?
* 2洗数据的话可能涉及的数据量较大也可以采用 user_id 进行过滤的方式,此时需要获取到 dept_id 对应的所有 user_id 用户编号;
* 最终过滤条件是 WHERE user_id IN (?, ?, ? ...)
* 3想要保证原 dept_id 和 user_id 都可以看的到,此时使用 dept_id 和 user_id 一起过滤;
* 最终过滤条件是 WHERE dept_id = ? OR user_id IN (?, ?, ? ...)
*
* @author 芋道源码
*/
@AllArgsConstructor
@Slf4j
public class DeptDataPermissionRule implements DataPermissionRule {
/**
* LoginUser 的 Context 缓存 Key
*/
protected static final String CONTEXT_KEY = DeptDataPermissionRule.class.getSimpleName();
private static final String DEPT_COLUMN_NAME = "dept_id";
private static final String USER_COLUMN_NAME = "user_id";
static final Expression EXPRESSION_NULL = new NullValue();
private final PermissionApi permissionApi;
/**
* 基于部门的表字段配置
* 一般情况下,每个表的部门编号字段是 dept_id通过该配置自定义。
*
* key表名
* value字段名
*/
private final Map<String, String> deptColumns = new HashMap<>();
/**
* 基于用户的表字段配置
* 一般情况下,每个表的部门编号字段是 dept_id通过该配置自定义。
*
* key表名
* value字段名
*/
private final Map<String, String> userColumns = new HashMap<>();
/**
* 所有表名,是 {@link #deptColumns} 和 {@link #userColumns} 的合集
*/
private final Set<String> TABLE_NAMES = new HashSet<>();
@Override
public Set<String> getTableNames() {
return TABLE_NAMES;
}
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
// 只有有登陆用户的情况下,才进行数据权限的处理
LoginUser loginUser = SecurityFrameworkUtils.getLoginUser();
if (loginUser == null) {
return null;
}
// 只有管理员类型的用户,才进行数据权限的处理
if (ObjectUtil.notEqual(loginUser.getUserType(), UserTypeEnum.ADMIN.getValue())) {
return null;
}
// 获得数据权限
DeptDataPermissionRespDTO deptDataPermission = loginUser.getContext(CONTEXT_KEY, DeptDataPermissionRespDTO.class);
// 从上下文中拿不到,则调用逻辑进行获取
if (deptDataPermission == null) {
deptDataPermission = permissionApi.getDeptDataPermission(loginUser.getId());
if (deptDataPermission == null) {
log.error("[getExpression][LoginUser({}) 获取数据权限为 null]", JsonUtils.toJsonString(loginUser));
throw new NullPointerException(String.format("LoginUser(%d) Table(%s/%s) 未返回数据权限",
loginUser.getId(), tableName, tableAlias.getName()));
}
// 添加到上下文中,避免重复计算
loginUser.setContext(CONTEXT_KEY, deptDataPermission);
}
// 情况一,如果是 ALL 可查看全部,则无需拼接条件
if (deptDataPermission.getAll()) {
return null;
}
// 情况二,即不能查看部门,又不能查看自己,则说明 100% 无权限
if (CollUtil.isEmpty(deptDataPermission.getDeptIds())
&& Boolean.FALSE.equals(deptDataPermission.getSelf())) {
return new EqualsTo(null, null); // WHERE null = null可以保证返回的数据为空
}
// 情况三,拼接 Dept 和 User 的条件,最后组合
Expression deptExpression = buildDeptExpression(tableName,tableAlias, deptDataPermission.getDeptIds());
Expression userExpression = buildUserExpression(tableName, tableAlias, deptDataPermission.getSelf(), loginUser.getId());
if (deptExpression == null && userExpression == null) {
// TODO 芋艿:获得不到条件的时候,暂时不抛出异常,而是不返回数据
log.warn("[getExpression][LoginUser({}) Table({}/{}) DeptDataPermission({}) 构建的条件为空]",
JsonUtils.toJsonString(loginUser), tableName, tableAlias, JsonUtils.toJsonString(deptDataPermission));
// throw new NullPointerException(String.format("LoginUser(%d) Table(%s/%s) 构建的条件为空",
// loginUser.getId(), tableName, tableAlias.getName()));
return EXPRESSION_NULL;
}
if (deptExpression == null) {
return userExpression;
}
if (userExpression == null) {
return deptExpression;
}
// 目前,如果有指定部门 + 可查看自己,采用 OR 条件。即WHERE (dept_id IN ? OR user_id = ?)
return new Parenthesis(new OrExpression(deptExpression, userExpression));
}
private Expression buildDeptExpression(String tableName, Alias tableAlias, Set<Long> deptIds) {
// 如果不存在配置,则无需作为条件
String columnName = deptColumns.get(tableName);
if (StrUtil.isEmpty(columnName)) {
return null;
}
// 如果为空,则无条件
if (CollUtil.isEmpty(deptIds)) {
return null;
}
// 拼接条件
return new InExpression(MyBatisUtils.buildColumn(tableName, tableAlias, columnName),
new ExpressionList(CollectionUtils.convertList(deptIds, LongValue::new)));
}
private Expression buildUserExpression(String tableName, Alias tableAlias, Boolean self, Long userId) {
// 如果不查看自己,则无需作为条件
if (Boolean.FALSE.equals(self)) {
return null;
}
String columnName = userColumns.get(tableName);
if (StrUtil.isEmpty(columnName)) {
return null;
}
// 拼接条件
return new EqualsTo(MyBatisUtils.buildColumn(tableName, tableAlias, columnName), new LongValue(userId));
}
// ==================== 添加配置 ====================
public void addDeptColumn(Class<? extends BaseDO> entityClass) {
addDeptColumn(entityClass, DEPT_COLUMN_NAME);
}
public void addDeptColumn(Class<? extends BaseDO> entityClass, String columnName) {
String tableName = TableInfoHelper.getTableInfo(entityClass).getTableName();
addDeptColumn(tableName, columnName);
}
public void addDeptColumn(String tableName, String columnName) {
deptColumns.put(tableName, columnName);
TABLE_NAMES.add(tableName);
}
public void addUserColumn(Class<? extends BaseDO> entityClass) {
addUserColumn(entityClass, USER_COLUMN_NAME);
}
public void addUserColumn(Class<? extends BaseDO> entityClass, String columnName) {
String tableName = TableInfoHelper.getTableInfo(entityClass).getTableName();
addUserColumn(tableName, columnName);
}
public void addUserColumn(String tableName, String columnName) {
userColumns.put(tableName, columnName);
TABLE_NAMES.add(tableName);
}
}

View File

@ -0,0 +1,20 @@
package cn.iocoder.yudao.framework.datapermission.core.rule.dept;
/**
* {@link DeptDataPermissionRule} 的自定义配置接口
*
* @author 芋道源码
*/
@FunctionalInterface
public interface DeptDataPermissionRuleCustomizer {
/**
* 自定义该权限规则
* 1. 调用 {@link DeptDataPermissionRule#addDeptColumn(Class, String)} 方法,配置基于 dept_id 的过滤规则
* 2. 调用 {@link DeptDataPermissionRule#addUserColumn(Class, String)} 方法,配置基于 user_id 的过滤规则
*
* @param rule 权限规则
*/
void customize(DeptDataPermissionRule rule);
}

View File

@ -0,0 +1,6 @@
/**
* 基于部门的数据权限规则
*
* @author 芋道源码
*/
package cn.iocoder.yudao.framework.datapermission.core.rule.dept;

View File

@ -0,0 +1,43 @@
package cn.iocoder.yudao.framework.datapermission.core.util;
import cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission;
import cn.iocoder.yudao.framework.datapermission.core.aop.DataPermissionContextHolder;
import lombok.SneakyThrows;
/**
* 数据权限 Util
*
* @author 芋道源码
*/
public class DataPermissionUtils {
private static DataPermission DATA_PERMISSION_DISABLE;
@DataPermission(enable = false)
@SneakyThrows
private static DataPermission getDisableDataPermissionDisable() {
if (DATA_PERMISSION_DISABLE == null) {
DATA_PERMISSION_DISABLE = DataPermissionUtils.class
.getDeclaredMethod("getDisableDataPermissionDisable")
.getAnnotation(DataPermission.class);
}
return DATA_PERMISSION_DISABLE;
}
/**
* 忽略数据权限,执行对应的逻辑
*
* @param runnable 逻辑
*/
public static void executeIgnore(Runnable runnable) {
DataPermission dataPermission = getDisableDataPermissionDisable();
DataPermissionContextHolder.add(dataPermission);
try {
// 执行 runnable
runnable.run();
} finally {
DataPermissionContextHolder.remove();
}
}
}

View File

@ -0,0 +1,4 @@
/**
* 基于 JSqlParser 解析 SQL增加数据权限的 WHERE 条件
*/
package cn.iocoder.yudao.framework.datapermission;

View File

@ -0,0 +1,2 @@
cn.iocoder.yudao.framework.datapermission.config.YudaoDataPermissionAutoConfiguration
cn.iocoder.yudao.framework.datapermission.config.YudaoDeptDataPermissionAutoConfiguration

View File

@ -0,0 +1,108 @@
package cn.iocoder.yudao.framework.datapermission.core.aop;
import cn.hutool.core.collection.CollUtil;
import cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission;
import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
import org.aopalliance.intercept.MethodInvocation;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import java.lang.reflect.Method;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.when;
/**
* {@link DataPermissionAnnotationInterceptor} 的单元测试
*
* @author 芋道源码
*/
public class DataPermissionAnnotationInterceptorTest extends BaseMockitoUnitTest {
@InjectMocks
private DataPermissionAnnotationInterceptor interceptor;
@Mock
private MethodInvocation methodInvocation;
@BeforeEach
public void setUp() {
interceptor.getDataPermissionCache().clear();
}
@Test // 无 @DataPermission 注解
public void testInvoke_none() throws Throwable {
// 参数
mockMethodInvocation(TestNone.class);
// 调用
Object result = interceptor.invoke(methodInvocation);
// 断言
assertEquals("none", result);
assertEquals(1, interceptor.getDataPermissionCache().size());
assertTrue(CollUtil.getFirst(interceptor.getDataPermissionCache().values()).enable());
}
@Test // 在 Method 上有 @DataPermission 注解
public void testInvoke_method() throws Throwable {
// 参数
mockMethodInvocation(TestMethod.class);
// 调用
Object result = interceptor.invoke(methodInvocation);
// 断言
assertEquals("method", result);
assertEquals(1, interceptor.getDataPermissionCache().size());
assertFalse(CollUtil.getFirst(interceptor.getDataPermissionCache().values()).enable());
}
@Test // 在 Class 上有 @DataPermission 注解
public void testInvoke_class() throws Throwable {
// 参数
mockMethodInvocation(TestClass.class);
// 调用
Object result = interceptor.invoke(methodInvocation);
// 断言
assertEquals("class", result);
assertEquals(1, interceptor.getDataPermissionCache().size());
assertFalse(CollUtil.getFirst(interceptor.getDataPermissionCache().values()).enable());
}
private void mockMethodInvocation(Class<?> clazz) throws Throwable {
Object targetObject = clazz.newInstance();
Method method = targetObject.getClass().getMethod("echo");
when(methodInvocation.getThis()).thenReturn(targetObject);
when(methodInvocation.getMethod()).thenReturn(method);
when(methodInvocation.proceed()).then(invocationOnMock -> method.invoke(targetObject));
}
static class TestMethod {
@DataPermission(enable = false)
public String echo() {
return "method";
}
}
@DataPermission(enable = false)
static class TestClass {
public String echo() {
return "class";
}
}
static class TestNone {
public String echo() {
return "none";
}
}
}

View File

@ -0,0 +1,66 @@
package cn.iocoder.yudao.framework.datapermission.core.aop;
import cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.mock;
/**
* {@link DataPermissionContextHolder} 的单元测试
*
* @author 芋道源码
*/
class DataPermissionContextHolderTest {
@BeforeEach
public void setUp() {
DataPermissionContextHolder.clear();
}
@Test
public void testGet() {
// mock 方法
DataPermission dataPermission01 = mock(DataPermission.class);
DataPermissionContextHolder.add(dataPermission01);
DataPermission dataPermission02 = mock(DataPermission.class);
DataPermissionContextHolder.add(dataPermission02);
// 调用
DataPermission result = DataPermissionContextHolder.get();
// 断言
assertSame(result, dataPermission02);
}
@Test
public void testPush() {
// 调用
DataPermission dataPermission01 = mock(DataPermission.class);
DataPermissionContextHolder.add(dataPermission01);
DataPermission dataPermission02 = mock(DataPermission.class);
DataPermissionContextHolder.add(dataPermission02);
// 断言
DataPermission first = DataPermissionContextHolder.getAll().get(0);
DataPermission second = DataPermissionContextHolder.getAll().get(1);
assertSame(dataPermission01, first);
assertSame(dataPermission02, second);
}
@Test
public void testRemove() {
// mock 方法
DataPermission dataPermission01 = mock(DataPermission.class);
DataPermissionContextHolder.add(dataPermission01);
DataPermission dataPermission02 = mock(DataPermission.class);
DataPermissionContextHolder.add(dataPermission02);
// 调用
DataPermission result = DataPermissionContextHolder.remove();
// 断言
assertSame(result, dataPermission02);
assertEquals(1, DataPermissionContextHolder.getAll().size());
}
}

View File

@ -0,0 +1,190 @@
package cn.iocoder.yudao.framework.datapermission.core.db;
import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.schema.Column;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import java.sql.Connection;
import java.util.*;
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
/**
* {@link DataPermissionDatabaseInterceptor} 的单元测试
* 主要测试 {@link DataPermissionDatabaseInterceptor#beforePrepare(StatementHandler, Connection, Integer)}
* 和 {@link DataPermissionDatabaseInterceptor#beforeUpdate(Executor, MappedStatement, Object)}
* 以及在这个过程中ContextHolder 和 MappedStatementCache
*
* @author 芋道源码
*/
public class DataPermissionDatabaseInterceptorTest extends BaseMockitoUnitTest {
@InjectMocks
private DataPermissionDatabaseInterceptor interceptor;
@Mock
private DataPermissionRuleFactory ruleFactory;
@BeforeEach
public void setUp() {
// 清理上下文
DataPermissionDatabaseInterceptor.ContextHolder.clear();
// 清空缓存
interceptor.getMappedStatementCache().clear();
}
@Test // 不存在规则,且不匹配
public void testBeforeQuery_withoutRule() {
try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
// 准备参数
MappedStatement mappedStatement = mock(MappedStatement.class);
BoundSql boundSql = mock(BoundSql.class);
// 调用
interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
// 断言
pluginUtilsMock.verify(() -> PluginUtils.mpBoundSql(boundSql), never());
}
}
@Test // 存在规则,且不匹配
public void testBeforeQuery_withMatchRule() {
try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
// 准备参数
MappedStatement mappedStatement = mock(MappedStatement.class);
BoundSql boundSql = mock(BoundSql.class);
// mock 方法(数据权限)
when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
.thenReturn(singletonList(new DeptDataPermissionRule()));
// mock 方法(MPBoundSql)
PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
// mock 方法(SQL)
String sql = "select * from t_user where id = 1";
when(mpBs.sql()).thenReturn(sql);
// 针对 ContextHolder 和 MappedStatementCache 暂时不 mock主要想校验过程中数据是否正确
// 调用
interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
// 断言
verify(mpBs, times(1)).sql(
eq("SELECT * FROM t_user WHERE id = 1 AND t_user.dept_id = 100"));
// 断言缓存
assertTrue(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
}
}
@Test // 存在规则,但不匹配
public void testBeforeQuery_withoutMatchRule() {
try (MockedStatic<PluginUtils> pluginUtilsMock = mockStatic(PluginUtils.class)) {
// 准备参数
MappedStatement mappedStatement = mock(MappedStatement.class);
BoundSql boundSql = mock(BoundSql.class);
// mock 方法(数据权限)
when(ruleFactory.getDataPermissionRule(same(mappedStatement.getId())))
.thenReturn(singletonList(new DeptDataPermissionRule()));
// mock 方法(MPBoundSql)
PluginUtils.MPBoundSql mpBs = mock(PluginUtils.MPBoundSql.class);
pluginUtilsMock.when(() -> PluginUtils.mpBoundSql(same(boundSql))).thenReturn(mpBs);
// mock 方法(SQL)
String sql = "select * from t_role where id = 1";
when(mpBs.sql()).thenReturn(sql);
// 针对 ContextHolder 和 MappedStatementCache 暂时不 mock主要想校验过程中数据是否正确
// 调用
interceptor.beforeQuery(null, mappedStatement, null, null, null, boundSql);
// 断言
verify(mpBs, times(1)).sql(
eq("SELECT * FROM t_role WHERE id = 1"));
// 断言缓存
assertFalse(interceptor.getMappedStatementCache().getNoRewritableMappedStatements().isEmpty());
}
}
@Test
public void testAddNoRewritable() {
// 准备参数
MappedStatement ms = mock(MappedStatement.class);
List<DataPermissionRule> rules = singletonList(new DeptDataPermissionRule());
// mock 方法
when(ms.getId()).thenReturn("selectById");
// 调用
interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
// 断言
Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements =
interceptor.getMappedStatementCache().getNoRewritableMappedStatements();
assertEquals(1, noRewritableMappedStatements.size());
assertEquals(SetUtils.asSet("selectById"), noRewritableMappedStatements.get(DeptDataPermissionRule.class));
}
@Test
public void testNoRewritable() {
// 准备参数
MappedStatement ms = mock(MappedStatement.class);
// mock 方法
when(ms.getId()).thenReturn("selectById");
// mock 数据
List<DataPermissionRule> rules = singletonList(new DeptDataPermissionRule());
interceptor.getMappedStatementCache().addNoRewritable(ms, rules);
// 场景一rules 为空
assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, null));
// 场景二rules 非空,可重写
assertFalse(interceptor.getMappedStatementCache().noRewritable(ms, singletonList(new EmptyDataPermissionRule())));
// 场景三rule 非空,不可重写
assertTrue(interceptor.getMappedStatementCache().noRewritable(ms, rules));
}
private static class DeptDataPermissionRule implements DataPermissionRule {
private static final String COLUMN = "dept_id";
@Override
public Set<String> getTableNames() {
return SetUtils.asSet("t_user");
}
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
Column column = MyBatisUtils.buildColumn(tableName, tableAlias, COLUMN);
LongValue value = new LongValue(100L);
return new EqualsTo(column, value);
}
}
private static class EmptyDataPermissionRule implements DataPermissionRule {
@Override
public Set<String> getTableNames() {
return Collections.emptySet();
}
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
return null;
}
}
}

View File

@ -0,0 +1,533 @@
package cn.iocoder.yudao.framework.datapermission.core.db;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.schema.Column;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import java.util.Arrays;
import java.util.Set;
import static cn.iocoder.yudao.framework.common.util.collection.SetUtils.asSet;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* {@link DataPermissionDatabaseInterceptor} 的单元测试
* 主要复用了 MyBatis Plus 的 TenantLineInnerInterceptorTest 的单元测试
* 不过它的单元测试不是很规范,考虑到是复用的,所以暂时不进行修改~
*
* @author 芋道源码
*/
public class DataPermissionDatabaseInterceptorTest2 extends BaseMockitoUnitTest {
@InjectMocks
private DataPermissionDatabaseInterceptor interceptor;
@Mock
private DataPermissionRuleFactory ruleFactory;
@BeforeEach
public void setUp() {
// 租户的数据权限规则
DataPermissionRule tenantRule = new DataPermissionRule() {
private static final String COLUMN = "tenant_id";
@Override
public Set<String> getTableNames() {
return asSet("entity", "entity1", "entity2", "entity3", "t1", "t2", "sys_dict_item", // 支持 MyBatis Plus 的单元测试
"t_user", "t_role"); // 满足自己的单元测试
}
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
Column column = MyBatisUtils.buildColumn(tableName, tableAlias, COLUMN);
LongValue value = new LongValue(1L);
return new EqualsTo(column, value);
}
};
// 部门的数据权限规则
DataPermissionRule deptRule = new DataPermissionRule() {
private static final String COLUMN = "dept_id";
@Override
public Set<String> getTableNames() {
return asSet("t_user"); // 满足自己的单元测试
}
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
Column column = MyBatisUtils.buildColumn(tableName, tableAlias, COLUMN);
ExpressionList values = new ExpressionList(new LongValue(10L),
new LongValue(20L));
return new InExpression(column, values);
}
};
// 设置到上下文,保证
DataPermissionDatabaseInterceptor.ContextHolder.init(Arrays.asList(tenantRule, deptRule));
}
@Test
void delete() {
assertSql("delete from entity where id = ?",
"DELETE FROM entity WHERE id = ? AND entity.tenant_id = 1");
}
@Test
void update() {
assertSql("update entity set name = ? where id = ?",
"UPDATE entity SET name = ? WHERE id = ? AND entity.tenant_id = 1");
}
@Test
void selectSingle() {
// 单表
assertSql("select * from entity where id = ?",
"SELECT * FROM entity WHERE id = ? AND entity.tenant_id = 1");
assertSql("select * from entity where id = ? or name = ?",
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1");
assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)",
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1");
/* not */
assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)",
"SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND entity.tenant_id = 1");
}
@Test
void selectSubSelectIn() {
/* in */
assertSql("SELECT * FROM entity e WHERE e.id IN (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id IN (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
// 在最前
assertSql("SELECT * FROM entity e WHERE e.id IN " +
"(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
"SELECT * FROM entity e WHERE e.id IN " +
"(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
// 在最后
assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
"(select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
"(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
// 在中间
assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
"(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
"SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
"(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
}
@Test
void selectSubSelectEq() {
/* = */
assertSql("SELECT * FROM entity e WHERE e.id = (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
}
@Test
void selectSubSelectInnerNotEq() {
/* inner not = */
assertSql("SELECT * FROM entity e WHERE not (e.id = (select e1.id from entity1 e1 where e1.id = ?))",
"SELECT * FROM entity e WHERE NOT (e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1)) AND e.tenant_id = 1");
assertSql("SELECT * FROM entity e WHERE not (e.id = (select e1.id from entity1 e1 where e1.id = ?) and e.id = ?)",
"SELECT * FROM entity e WHERE NOT (e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ?) AND e.tenant_id = 1");
}
@Test
void selectSubSelectExists() {
/* EXISTS */
assertSql("SELECT * FROM entity e WHERE EXISTS (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE EXISTS (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
/* NOT EXISTS */
assertSql("SELECT * FROM entity e WHERE NOT EXISTS (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE NOT EXISTS (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
}
@Test
void selectSubSelect() {
/* >= */
assertSql("SELECT * FROM entity e WHERE e.id >= (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id >= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
/* <= */
assertSql("SELECT * FROM entity e WHERE e.id <= (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id <= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
/* <> */
assertSql("SELECT * FROM entity e WHERE e.id <> (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id <> (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
}
@Test
void selectFromSelect() {
assertSql("SELECT * FROM (select e.id from entity e WHERE e.id = (select e1.id from entity1 e1 where e1.id = ?))",
"SELECT * FROM (SELECT e.id FROM entity e WHERE e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1)");
}
@Test
void selectBodySubSelect() {
assertSql("select t1.col1,(select t2.col2 from t2 t2 where t1.col1=t2.col1) from t1 t1",
"SELECT t1.col1, (SELECT t2.col2 FROM t2 t2 WHERE t1.col1 = t2.col1 AND t2.tenant_id = 1) FROM t1 t1 WHERE t1.tenant_id = 1");
}
@Test
void selectLeftJoin() {
// left join
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
}
@Test
void selectRightJoin() {
// right join
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM with_as_1 e " +
"right join entity1 e1 on e1.id = e.id",
"SELECT * FROM with_as_1 e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id " +
"WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"right join entity2 e2 on e1.id = e2.id ",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
"WHERE e2.tenant_id = 1");
}
@Test
void selectMixJoin() {
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"right join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
"WHERE e2.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"inner join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"INNER JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 AND e2.tenant_id = 1");
}
@Test
void selectJoinSubSelect() {
assertSql("select * from (select * from entity) e1 " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM (SELECT * FROM entity WHERE entity.tenant_id = 1) e1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1");
assertSql("select * from entity1 e1 " +
"left join (select * from entity2) e2 " +
"on e1.id = e2.id",
"SELECT * FROM entity1 e1 " +
"LEFT JOIN (SELECT * FROM entity2 WHERE entity2.tenant_id = 1) e2 " +
"ON e1.id = e2.id " +
"WHERE e1.tenant_id = 1");
}
@Test
void selectSubJoin() {
assertSql("select * FROM " +
"(entity1 e1 right JOIN entity2 e2 ON e1.id = e2.id)",
"SELECT * FROM " +
"(entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
"WHERE e2.tenant_id = 1");
assertSql("select * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id)",
"SELECT * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"WHERE e1.tenant_id = 1");
assertSql("select * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id) " +
"right join entity3 e3 on e1.id = e3.id",
"SELECT * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"RIGHT JOIN entity3 e3 ON e1.id = e3.id AND e1.tenant_id = 1 " +
"WHERE e3.tenant_id = 1");
assertSql("select * FROM entity e " +
"LEFT JOIN (entity1 e1 right join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN (entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
"ON e.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
assertSql("select * FROM entity e " +
"LEFT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"ON e.id = e2.id AND e1.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
assertSql("select * FROM entity e " +
"RIGHT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"RIGHT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"ON e.id = e2.id AND e.tenant_id = 1 " +
"WHERE e1.tenant_id = 1");
}
@Test
void selectLeftJoinMultipleTrailingOn() {
// 多个 on 尾缀的
assertSql("SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 " +
"LEFT JOIN entity2 e2 ON e2.id = e1.id " +
"ON e1.id = e.id " +
"WHERE (e.id = ? OR e.NAME = ?)",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 " +
"LEFT JOIN entity2 e2 ON e2.id = e1.id AND e2.tenant_id = 1 " +
"ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.NAME = ?) AND e.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 " +
"LEFT JOIN with_as_A e2 ON e2.id = e1.id " +
"ON e1.id = e.id " +
"WHERE (e.id = ? OR e.NAME = ?)",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 " +
"LEFT JOIN with_as_A e2 ON e2.id = e1.id " +
"ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.NAME = ?) AND e.tenant_id = 1");
}
@Test
void selectInnerJoin() {
// inner join
assertSql("SELECT * FROM entity e " +
"inner join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
"WHERE e.id = ? OR e.name = ?");
assertSql("SELECT * FROM entity e " +
"inner join entity1 e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM entity e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?)");
// 隐式内连接
assertSql("SELECT * FROM entity,entity1 " +
"WHERE entity.id = entity1.id",
"SELECT * FROM entity, entity1 " +
"WHERE entity.id = entity1.id AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
// 隐式内连接
assertSql("SELECT * FROM entity a, with_as_entity1 b " +
"WHERE a.id = b.id",
"SELECT * FROM entity a, with_as_entity1 b " +
"WHERE a.id = b.id AND a.tenant_id = 1");
assertSql("SELECT * FROM with_as_entity a, with_as_entity1 b " +
"WHERE a.id = b.id",
"SELECT * FROM with_as_entity a, with_as_entity1 b " +
"WHERE a.id = b.id");
// SubJoin with 隐式内连接
assertSql("SELECT * FROM (entity,entity1) " +
"WHERE entity.id = entity1.id",
"SELECT * FROM (entity, entity1) " +
"WHERE entity.id = entity1.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
assertSql("SELECT * FROM ((entity,entity1),entity2) " +
"WHERE entity.id = entity1.id and entity.id = entity2.id",
"SELECT * FROM ((entity, entity1), entity2) " +
"WHERE entity.id = entity1.id AND entity.id = entity2.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1");
assertSql("SELECT * FROM (entity,(entity1,entity2)) " +
"WHERE entity.id = entity1.id and entity.id = entity2.id",
"SELECT * FROM (entity, (entity1, entity2)) " +
"WHERE entity.id = entity1.id AND entity.id = entity2.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1");
// 沙雕的括号写法
assertSql("SELECT * FROM (((entity,entity1))) " +
"WHERE entity.id = entity1.id",
"SELECT * FROM (((entity, entity1))) " +
"WHERE entity.id = entity1.id " +
"AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
}
@Test
void selectWithAs() {
assertSql("with with_as_A as (select * from entity) select * from with_as_A",
"WITH with_as_A AS (SELECT * FROM entity WHERE entity.tenant_id = 1) SELECT * FROM with_as_A");
}
@Test
void selectIgnoreTable() {
assertSql(" SELECT dict.dict_code, item.item_text AS \"text\", item.item_value AS \"value\" FROM sys_dict_item item INNER JOIN sys_dict dict ON dict.id = item.dict_id WHERE dict.dict_code IN (1, 2, 3) AND item.item_value IN (1, 2, 3)",
"SELECT dict.dict_code, item.item_text AS \"text\", item.item_value AS \"value\" FROM sys_dict_item item INNER JOIN sys_dict dict ON dict.id = item.dict_id AND item.tenant_id = 1 WHERE dict.dict_code IN (1, 2, 3) AND item.item_value IN (1, 2, 3)");
}
private void assertSql(String sql, String targetSql) {
assertEquals(targetSql, interceptor.parserSingle(sql, null));
}
// ========== 额外的测试 ==========
@Test
public void testSelectSingle() {
// 单表
assertSql("select * from t_user where id = ?",
"SELECT * FROM t_user WHERE id = ? AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
assertSql("select * from t_user where id = ? or name = ?",
"SELECT * FROM t_user WHERE (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
assertSql("SELECT * FROM t_user WHERE (id = ? OR name = ?)",
"SELECT * FROM t_user WHERE (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
/* not */
assertSql("SELECT * FROM t_user WHERE not (id = ? OR name = ?)",
"SELECT * FROM t_user WHERE NOT (id = ? OR name = ?) AND t_user.tenant_id = 1 AND t_user.dept_id IN (10, 20)");
}
@Test
public void testSelectLeftJoin() {
// left join
assertSql("SELECT * FROM t_user e " +
"left join t_role e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM t_user e " +
"LEFT JOIN t_role e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)");
// 条件 e.id = ? OR e.name = ? 带括号
assertSql("SELECT * FROM t_user e " +
"left join t_role e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM t_user e " +
"LEFT JOIN t_role e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e.dept_id IN (10, 20)");
}
@Test
public void testSelectRightJoin() {
// right join
assertSql("SELECT * FROM t_user e " +
"right join t_role e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM t_user e " +
"RIGHT JOIN t_role e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) " +
"WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
// 条件 e.id = ? OR e.name = ? 带括号
assertSql("SELECT * FROM t_user e " +
"right join t_role e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM t_user e " +
"RIGHT JOIN t_role e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) " +
"WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
}
@Test
public void testSelectInnerJoin() {
// inner join
assertSql("SELECT * FROM t_user e " +
"inner join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM t_user e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) AND e1.tenant_id = 1 " +
"WHERE e.id = ? OR e.name = ?");
// 条件 e.id = ? OR e.name = ? 带括号
assertSql("SELECT * FROM t_user e " +
"inner join entity1 e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM t_user e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e.dept_id IN (10, 20) AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?)");
// 没有 On 的 inner join
assertSql("SELECT * FROM entity,entity1 " +
"WHERE entity.id = entity1.id",
"SELECT * FROM entity, entity1 " +
"WHERE entity.id = entity1.id AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
}
}

View File

@ -0,0 +1,145 @@
package cn.iocoder.yudao.framework.datapermission.core.rule;
import cn.iocoder.yudao.framework.datapermission.core.annotation.DataPermission;
import cn.iocoder.yudao.framework.datapermission.core.aop.DataPermissionContextHolder;
import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Spy;
import org.springframework.core.annotation.AnnotationUtils;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.randomString;
import static org.junit.jupiter.api.Assertions.*;
/**
* {@link DataPermissionRuleFactoryImpl} 单元测试
*
* @author 芋道源码
*/
class DataPermissionRuleFactoryImplTest extends BaseMockitoUnitTest {
@InjectMocks
private DataPermissionRuleFactoryImpl dataPermissionRuleFactory;
@Spy
private List<DataPermissionRule> rules = Arrays.asList(new DataPermissionRule01(),
new DataPermissionRule02());
@BeforeEach
public void setUp() {
DataPermissionContextHolder.clear();
}
@Test
public void testGetDataPermissionRule_02() {
// 准备参数
String mappedStatementId = randomString();
// 调用
List<DataPermissionRule> result = dataPermissionRuleFactory.getDataPermissionRule(mappedStatementId);
// 断言
assertSame(rules, result);
}
@Test
public void testGetDataPermissionRule_03() {
// 准备参数
String mappedStatementId = randomString();
// mock 方法
DataPermissionContextHolder.add(AnnotationUtils.findAnnotation(TestClass03.class, DataPermission.class));
// 调用
List<DataPermissionRule> result = dataPermissionRuleFactory.getDataPermissionRule(mappedStatementId);
// 断言
assertTrue(result.isEmpty());
}
@Test
public void testGetDataPermissionRule_04() {
// 准备参数
String mappedStatementId = randomString();
// mock 方法
DataPermissionContextHolder.add(AnnotationUtils.findAnnotation(TestClass04.class, DataPermission.class));
// 调用
List<DataPermissionRule> result = dataPermissionRuleFactory.getDataPermissionRule(mappedStatementId);
// 断言
assertEquals(1, result.size());
assertEquals(DataPermissionRule01.class, result.get(0).getClass());
}
@Test
public void testGetDataPermissionRule_05() {
// 准备参数
String mappedStatementId = randomString();
// mock 方法
DataPermissionContextHolder.add(AnnotationUtils.findAnnotation(TestClass05.class, DataPermission.class));
// 调用
List<DataPermissionRule> result = dataPermissionRuleFactory.getDataPermissionRule(mappedStatementId);
// 断言
assertEquals(1, result.size());
assertEquals(DataPermissionRule02.class, result.get(0).getClass());
}
@Test
public void testGetDataPermissionRule_06() {
// 准备参数
String mappedStatementId = randomString();
// mock 方法
DataPermissionContextHolder.add(AnnotationUtils.findAnnotation(TestClass06.class, DataPermission.class));
// 调用
List<DataPermissionRule> result = dataPermissionRuleFactory.getDataPermissionRule(mappedStatementId);
// 断言
assertSame(rules, result);
}
@DataPermission(enable = false)
static class TestClass03 {}
@DataPermission(includeRules = DataPermissionRule01.class)
static class TestClass04 {}
@DataPermission(excludeRules = DataPermissionRule01.class)
static class TestClass05 {}
@DataPermission
static class TestClass06 {}
static class DataPermissionRule01 implements DataPermissionRule {
@Override
public Set<String> getTableNames() {
return null;
}
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
return null;
}
}
static class DataPermissionRule02 implements DataPermissionRule {
@Override
public Set<String> getTableNames() {
return null;
}
@Override
public Expression getExpression(String tableName, Alias tableAlias) {
return null;
}
}
}

View File

@ -0,0 +1,238 @@
package cn.iocoder.yudao.framework.datapermission.core.rule.dept;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ReflectUtil;
import cn.iocoder.yudao.framework.common.enums.UserTypeEnum;
import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
import cn.iocoder.yudao.framework.security.core.LoginUser;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.framework.test.core.ut.BaseMockitoUnitTest;
import cn.iocoder.yudao.module.system.api.permission.PermissionApi;
import cn.iocoder.yudao.module.system.api.permission.dto.DeptDataPermissionRespDTO;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import java.util.Map;
import static cn.iocoder.yudao.framework.datapermission.core.rule.dept.DeptDataPermissionRule.EXPRESSION_NULL;
import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.randomPojo;
import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.randomString;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
/**
* {@link DeptDataPermissionRule} 的单元测试
*
* @author 芋道源码
*/
class DeptDataPermissionRuleTest extends BaseMockitoUnitTest {
@InjectMocks
private DeptDataPermissionRule rule;
@Mock
private PermissionApi permissionApi;
@BeforeEach
@SuppressWarnings("unchecked")
public void setUp() {
// 清空 rule
rule.getTableNames().clear();
((Map<String, String>) ReflectUtil.getFieldValue(rule, "deptColumns")).clear();
((Map<String, String>) ReflectUtil.getFieldValue(rule, "deptColumns")).clear();
}
@Test // 无 LoginUser
public void testGetExpression_noLoginUser() {
// 准备参数
String tableName = randomString();
Alias tableAlias = new Alias(randomString());
// mock 方法
// 调用
Expression expression = rule.getExpression(tableName, tableAlias);
// 断言
assertNull(expression);
}
@Test // 无数据权限时
public void testGetExpression_noDeptDataPermission() {
try (MockedStatic<SecurityFrameworkUtils> securityFrameworkUtilsMock
= mockStatic(SecurityFrameworkUtils.class)) {
// 准备参数
String tableName = "t_user";
Alias tableAlias = new Alias("u");
// mock 方法
LoginUser loginUser = randomPojo(LoginUser.class, o -> o.setId(1L)
.setUserType(UserTypeEnum.ADMIN.getValue()));
securityFrameworkUtilsMock.when(SecurityFrameworkUtils::getLoginUser).thenReturn(loginUser);
// mock 方法permissionApi 返回 null
when(permissionApi.getDeptDataPermission(eq(loginUser.getId()))).thenReturn(null);
// 调用
NullPointerException exception = assertThrows(NullPointerException.class,
() -> rule.getExpression(tableName, tableAlias));
// 断言
assertEquals("LoginUser(1) Table(t_user/u) 未返回数据权限", exception.getMessage());
}
}
@Test // 全部数据权限
public void testGetExpression_allDeptDataPermission() {
try (MockedStatic<SecurityFrameworkUtils> securityFrameworkUtilsMock
= mockStatic(SecurityFrameworkUtils.class)) {
// 准备参数
String tableName = "t_user";
Alias tableAlias = new Alias("u");
// mock 方法LoginUser
LoginUser loginUser = randomPojo(LoginUser.class, o -> o.setId(1L)
.setUserType(UserTypeEnum.ADMIN.getValue()));
securityFrameworkUtilsMock.when(SecurityFrameworkUtils::getLoginUser).thenReturn(loginUser);
// mock 方法DeptDataPermissionRespDTO
DeptDataPermissionRespDTO deptDataPermission = new DeptDataPermissionRespDTO().setAll(true);
when(permissionApi.getDeptDataPermission(same(1L))).thenReturn(deptDataPermission);
// 调用
Expression expression = rule.getExpression(tableName, tableAlias);
// 断言
assertNull(expression);
assertSame(deptDataPermission, loginUser.getContext(DeptDataPermissionRule.CONTEXT_KEY, DeptDataPermissionRespDTO.class));
}
}
@Test // 即不能查看部门,又不能查看自己,则说明 100% 无权限
public void testGetExpression_noDept_noSelf() {
try (MockedStatic<SecurityFrameworkUtils> securityFrameworkUtilsMock
= mockStatic(SecurityFrameworkUtils.class)) {
// 准备参数
String tableName = "t_user";
Alias tableAlias = new Alias("u");
// mock 方法LoginUser
LoginUser loginUser = randomPojo(LoginUser.class, o -> o.setId(1L)
.setUserType(UserTypeEnum.ADMIN.getValue()));
securityFrameworkUtilsMock.when(SecurityFrameworkUtils::getLoginUser).thenReturn(loginUser);
// mock 方法DeptDataPermissionRespDTO
DeptDataPermissionRespDTO deptDataPermission = new DeptDataPermissionRespDTO();
when(permissionApi.getDeptDataPermission(same(1L))).thenReturn(deptDataPermission);
// 调用
Expression expression = rule.getExpression(tableName, tableAlias);
// 断言
assertEquals("null = null", expression.toString());
assertSame(deptDataPermission, loginUser.getContext(DeptDataPermissionRule.CONTEXT_KEY, DeptDataPermissionRespDTO.class));
}
}
@Test // 拼接 Dept 和 User 的条件(字段都不符合)
public void testGetExpression_noDeptColumn_noSelfColumn() {
try (MockedStatic<SecurityFrameworkUtils> securityFrameworkUtilsMock
= mockStatic(SecurityFrameworkUtils.class)) {
// 准备参数
String tableName = "t_user";
Alias tableAlias = new Alias("u");
// mock 方法LoginUser
LoginUser loginUser = randomPojo(LoginUser.class, o -> o.setId(1L)
.setUserType(UserTypeEnum.ADMIN.getValue()));
securityFrameworkUtilsMock.when(SecurityFrameworkUtils::getLoginUser).thenReturn(loginUser);
// mock 方法DeptDataPermissionRespDTO
DeptDataPermissionRespDTO deptDataPermission = new DeptDataPermissionRespDTO()
.setDeptIds(SetUtils.asSet(10L, 20L)).setSelf(true);
when(permissionApi.getDeptDataPermission(same(1L))).thenReturn(deptDataPermission);
// 调用
Expression expression = rule.getExpression(tableName, tableAlias);
// 断言
assertSame(EXPRESSION_NULL, expression);
assertSame(deptDataPermission, loginUser.getContext(DeptDataPermissionRule.CONTEXT_KEY, DeptDataPermissionRespDTO.class));
}
}
@Test // 拼接 Dept 和 User 的条件self 符合)
public void testGetExpression_noDeptColumn_yesSelfColumn() {
try (MockedStatic<SecurityFrameworkUtils> securityFrameworkUtilsMock
= mockStatic(SecurityFrameworkUtils.class)) {
// 准备参数
String tableName = "t_user";
Alias tableAlias = new Alias("u");
// mock 方法LoginUser
LoginUser loginUser = randomPojo(LoginUser.class, o -> o.setId(1L)
.setUserType(UserTypeEnum.ADMIN.getValue()));
securityFrameworkUtilsMock.when(SecurityFrameworkUtils::getLoginUser).thenReturn(loginUser);
// mock 方法DeptDataPermissionRespDTO
DeptDataPermissionRespDTO deptDataPermission = new DeptDataPermissionRespDTO()
.setSelf(true);
when(permissionApi.getDeptDataPermission(same(1L))).thenReturn(deptDataPermission);
// 添加 user 字段配置
rule.addUserColumn("t_user", "id");
// 调用
Expression expression = rule.getExpression(tableName, tableAlias);
// 断言
assertEquals("u.id = 1", expression.toString());
assertSame(deptDataPermission, loginUser.getContext(DeptDataPermissionRule.CONTEXT_KEY, DeptDataPermissionRespDTO.class));
}
}
@Test // 拼接 Dept 和 User 的条件dept 符合)
public void testGetExpression_yesDeptColumn_noSelfColumn() {
try (MockedStatic<SecurityFrameworkUtils> securityFrameworkUtilsMock
= mockStatic(SecurityFrameworkUtils.class)) {
// 准备参数
String tableName = "t_user";
Alias tableAlias = new Alias("u");
// mock 方法LoginUser
LoginUser loginUser = randomPojo(LoginUser.class, o -> o.setId(1L)
.setUserType(UserTypeEnum.ADMIN.getValue()));
securityFrameworkUtilsMock.when(SecurityFrameworkUtils::getLoginUser).thenReturn(loginUser);
// mock 方法DeptDataPermissionRespDTO
DeptDataPermissionRespDTO deptDataPermission = new DeptDataPermissionRespDTO()
.setDeptIds(CollUtil.newLinkedHashSet(10L, 20L));
when(permissionApi.getDeptDataPermission(same(1L))).thenReturn(deptDataPermission);
// 添加 dept 字段配置
rule.addDeptColumn("t_user", "dept_id");
// 调用
Expression expression = rule.getExpression(tableName, tableAlias);
// 断言
assertEquals("u.dept_id IN (10, 20)", expression.toString());
assertSame(deptDataPermission, loginUser.getContext(DeptDataPermissionRule.CONTEXT_KEY, DeptDataPermissionRespDTO.class));
}
}
@Test // 拼接 Dept 和 User 的条件dept + self 符合)
public void testGetExpression_yesDeptColumn_yesSelfColumn() {
try (MockedStatic<SecurityFrameworkUtils> securityFrameworkUtilsMock
= mockStatic(SecurityFrameworkUtils.class)) {
// 准备参数
String tableName = "t_user";
Alias tableAlias = new Alias("u");
// mock 方法LoginUser
LoginUser loginUser = randomPojo(LoginUser.class, o -> o.setId(1L)
.setUserType(UserTypeEnum.ADMIN.getValue()));
securityFrameworkUtilsMock.when(SecurityFrameworkUtils::getLoginUser).thenReturn(loginUser);
// mock 方法DeptDataPermissionRespDTO
DeptDataPermissionRespDTO deptDataPermission = new DeptDataPermissionRespDTO()
.setDeptIds(CollUtil.newLinkedHashSet(10L, 20L)).setSelf(true);
when(permissionApi.getDeptDataPermission(same(1L))).thenReturn(deptDataPermission);
// 添加 user 字段配置
rule.addUserColumn("t_user", "id");
// 添加 dept 字段配置
rule.addDeptColumn("t_user", "dept_id");
// 调用
Expression expression = rule.getExpression(tableName, tableAlias);
// 断言
assertEquals("(u.dept_id IN (10, 20) OR u.id = 1)", expression.toString());
assertSame(deptDataPermission, loginUser.getContext(DeptDataPermissionRule.CONTEXT_KEY, DeptDataPermissionRespDTO.class));
}
}
}

View File

@ -0,0 +1,15 @@
package cn.iocoder.yudao.framework.datapermission.core.util;
import cn.iocoder.yudao.framework.datapermission.core.aop.DataPermissionContextHolder;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class DataPermissionUtilsTest {
@Test
public void testExecuteIgnore() {
DataPermissionUtils.executeIgnore(() -> assertFalse(DataPermissionContextHolder.get().enable()));
}
}

View File

@ -0,0 +1,54 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>yudao-framework</artifactId>
<groupId>cn.iocoder.boot</groupId>
<version>${revision}</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>yudao-spring-boot-starter-biz-ip</artifactId>
<packaging>jar</packaging>
<name>${project.artifactId}</name>
<description>IP 拓展,支持如下功能:
1. IP 功能:查询 IP 对应的城市信息
基于 https://gitee.com/lionsoul/ip2region 实现
2. 城市功能:查询城市编码对应的城市信息
基于 https://github.com/modood/Administrative-divisions-of-China 实现
</description>
<url>https://github.com/YunaiV/ruoyi-vue-pro</url>
<dependencies>
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-common</artifactId>
</dependency>
<!-- IP地址检索 -->
<dependency>
<groupId>org.lionsoul</groupId>
<artifactId>ip2region</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<scope>provided</scope> <!-- 设置为 provided只有工具类需要使用到 -->
</dependency>
<!-- Test 测试相关 -->
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,55 @@
package cn.iocoder.yudao.framework.ip.core;
import cn.iocoder.yudao.framework.ip.core.enums.AreaTypeEnum;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
* 区域节点,包括国家、省份、城市、地区等信息
*
* 数据可见 resources/area.csv 文件
*
* @author 芋道源码
*/
@Data
@AllArgsConstructor
@NoArgsConstructor
public class Area {
/**
* 编号 - 全球,即根目录
*/
public static final Integer ID_GLOBAL = 0;
/**
* 编号 - 中国
*/
public static final Integer ID_CHINA = 1;
/**
* 编号
*/
private Integer id;
/**
* 名字
*/
private String name;
/**
* 类型
*
* 枚举 {@link AreaTypeEnum}
*/
private Integer type;
/**
* 父节点
*/
private Area parent;
/**
* 子节点
*/
private List<Area> children;
}

View File

@ -0,0 +1,39 @@
package cn.iocoder.yudao.framework.ip.core.enums;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* 区域类型枚举
*
* @author 芋道源码
*/
@AllArgsConstructor
@Getter
public enum AreaTypeEnum implements IntArrayValuable {
COUNTRY(1, "国家"),
PROVINCE(2, "省份"),
CITY(3, "城市"),
DISTRICT(4, "地区"), // 县、镇、区等
;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AreaTypeEnum::getType).toArray();
/**
* 类型
*/
private final Integer type;
/**
* 名字
*/
private final String name;
@Override
public int[] array() {
return ARRAYS;
}
}

View File

@ -0,0 +1,214 @@
package cn.iocoder.yudao.framework.ip.core.utils;
import cn.hutool.core.io.resource.ResourceUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.text.csv.CsvRow;
import cn.hutool.core.text.csv.CsvUtil;
import cn.iocoder.yudao.framework.common.util.object.ObjectUtils;
import cn.iocoder.yudao.framework.ip.core.Area;
import cn.iocoder.yudao.framework.ip.core.enums.AreaTypeEnum;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.findFirst;
/**
* 区域工具类
*
* @author 芋道源码
*/
@Slf4j
public class AreaUtils {
/**
* 初始化 SEARCHER
*/
@SuppressWarnings("InstantiationOfUtilityClass")
private final static AreaUtils INSTANCE = new AreaUtils();
/**
* Area 内存缓存,提升访问速度
*/
private static Map<Integer, Area> areas;
private AreaUtils() {
long now = System.currentTimeMillis();
areas = new HashMap<>();
areas.put(Area.ID_GLOBAL, new Area(Area.ID_GLOBAL, "全球", 0,
null, new ArrayList<>()));
// 从 csv 中加载数据
List<CsvRow> rows = CsvUtil.getReader().read(ResourceUtil.getUtf8Reader("area.csv")).getRows();
rows.remove(0); // 删除 header
for (CsvRow row : rows) {
// 创建 Area 对象
Area area = new Area(Integer.valueOf(row.get(0)), row.get(1), Integer.valueOf(row.get(2)),
null, new ArrayList<>());
// 添加到 areas 中
areas.put(area.getId(), area);
}
// 构建父子关系:因为 Area 中没有 parentId 字段,所以需要重复读取
for (CsvRow row : rows) {
Area area = areas.get(Integer.valueOf(row.get(0))); // 自己
Area parent = areas.get(Integer.valueOf(row.get(3))); // 父
Assert.isTrue(area != parent, "{}:父子节点相同", area.getName());
area.setParent(parent);
parent.getChildren().add(area);
}
log.info("启动加载 AreaUtils 成功,耗时 ({}) 毫秒", System.currentTimeMillis() - now);
}
/**
* 获得指定编号对应的区域
*
* @param id 区域编号
* @return 区域
*/
public static Area getArea(Integer id) {
return areas.get(id);
}
/**
* 获得指定区域对应的编号
*
* @param pathStr 区域路径,例如说:河南省/石家庄市/新华区
* @return 区域
*/
public static Area parseArea(String pathStr) {
String[] paths = pathStr.split("/");
Area area = null;
for (String path : paths) {
if (area == null) {
area = findFirst(areas.values(), item -> item.getName().equals(path));
} else {
area = findFirst(area.getChildren(), item -> item.getName().equals(path));
}
}
return area;
}
/**
* 获取所有节点的全路径名称如:河南省/石家庄市/新华区
*
* @param areas 地区树
* @return 所有节点的全路径名称
*/
public static List<String> getAreaNodePathList(List<Area> areas) {
List<String> paths = new ArrayList<>();
areas.forEach(area -> getAreaNodePathList(area, "", paths));
return paths;
}
/**
* 构建一棵树的所有节点的全路径名称,并将其存储为 "祖先/父级/子级" 的形式
*
* @param node 父节点
* @param path 全路径名称
* @param paths 全路径名称列表,省份/城市/地区
*/
private static void getAreaNodePathList(Area node, String path, List<String> paths) {
if (node == null) {
return;
}
// 构建当前节点的路径
String currentPath = path.isEmpty() ? node.getName() : path + "/" + node.getName();
paths.add(currentPath);
// 递归遍历子节点
for (Area child : node.getChildren()) {
getAreaNodePathList(child, currentPath, paths);
}
}
/**
* 格式化区域
*
* @param id 区域编号
* @return 格式化后的区域
*/
public static String format(Integer id) {
return format(id, " ");
}
/**
* 格式化区域
*
* 例如说:
* 1. id = “静安区”时:上海 上海市 静安区
* 2. id = “上海市”时:上海 上海市
* 3. id = “上海”时:上海
* 4. id = “美国”时:美国
* 当区域在中国时,默认不显示中国
*
* @param id 区域编号
* @param separator 分隔符
* @return 格式化后的区域
*/
public static String format(Integer id, String separator) {
// 获得区域
Area area = areas.get(id);
if (area == null) {
return null;
}
// 格式化
StringBuilder sb = new StringBuilder();
for (int i = 0; i < AreaTypeEnum.values().length; i++) { // 避免死循环
sb.insert(0, area.getName());
// “递归”父节点
area = area.getParent();
if (area == null
|| ObjectUtils.equalsAny(area.getId(), Area.ID_GLOBAL, Area.ID_CHINA)) { // 跳过父节点为中国的情况
break;
}
sb.insert(0, separator);
}
return sb.toString();
}
/**
* 获取指定类型的区域列表
*
* @param type 区域类型
* @param func 转换函数
* @param <T> 结果类型
* @return 区域列表
*/
public static <T> List<T> getByType(AreaTypeEnum type, Function<Area, T> func) {
return convertList(areas.values(), func, area -> type.getType().equals(area.getType()));
}
/**
* 根据区域编号、上级区域类型,获取上级区域编号
*
* @param id 区域编号
* @param type 区域类型
* @return 上级区域编号
*/
public static Integer getParentIdByType(Integer id, @NonNull AreaTypeEnum type) {
for (int i = 0; i < Byte.MAX_VALUE; i++) {
Area area = AreaUtils.getArea(id);
if (area == null) {
return null;
}
// 情况一:匹配到,返回它
if (type.getType().equals(area.getType())) {
return area.getId();
}
// 情况二:找到根节点,返回空
if (area.getParent() == null || area.getParent().getId() == null) {
return null;
}
// 其它:继续向上查找
id = area.getParent().getId();
}
return null;
}
}

View File

@ -0,0 +1,87 @@
package cn.iocoder.yudao.framework.ip.core.utils;
import cn.hutool.core.io.resource.ResourceUtil;
import cn.iocoder.yudao.framework.ip.core.Area;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.lionsoul.ip2region.xdb.Searcher;
import java.io.IOException;
/**
* IP 工具类
*
* IP 数据源来自 ip2region.xdb 精简版,基于 <a href="https://gitee.com/zhijiantianya/ip2region"/> 项目
*
* @author wanglhup
*/
@Slf4j
public class IPUtils {
/**
* 初始化 SEARCHER
*/
@SuppressWarnings("InstantiationOfUtilityClass")
private final static IPUtils INSTANCE = new IPUtils();
/**
* IP 查询器,启动加载到内存中
*/
private static Searcher SEARCHER;
/**
* 私有化构造
*/
private IPUtils() {
try {
long now = System.currentTimeMillis();
byte[] bytes = ResourceUtil.readBytes("ip2region.xdb");
SEARCHER = Searcher.newWithBuffer(bytes);
log.info("启动加载 IPUtils 成功,耗时 ({}) 毫秒", System.currentTimeMillis() - now);
} catch (IOException e) {
log.error("启动加载 IPUtils 失败", e);
}
}
/**
* 查询 IP 对应的地区编号
*
* @param ip IP 地址,格式为 127.0.0.1
* @return 地区id
*/
@SneakyThrows
public static Integer getAreaId(String ip) {
return Integer.parseInt(SEARCHER.search(ip.trim()));
}
/**
* 查询 IP 对应的地区编号
*
* @param ip IP 地址的时间戳,格式参考{@link Searcher#checkIP(String)} 的返回
* @return 地区编号
*/
@SneakyThrows
public static Integer getAreaId(long ip) {
return Integer.parseInt(SEARCHER.search(ip));
}
/**
* 查询 IP 对应的地区
*
* @param ip IP 地址,格式为 127.0.0.1
* @return 地区
*/
public static Area getArea(String ip) {
return AreaUtils.getArea(getAreaId(ip));
}
/**
* 查询 IP 对应的地区
*
* @param ip IP 地址的时间戳,格式参考{@link Searcher#checkIP(String)} 的返回
* @return 地区
*/
public static Area getArea(long ip) {
return AreaUtils.getArea(getAreaId(ip));
}
}

View File

@ -0,0 +1,11 @@
/**
* IP 拓展,支持如下功能:
*
* 1. IP 功能:查询 IP 对应的城市信息
* 基于 https://gitee.com/lionsoul/ip2region 实现
* 2. 城市功能:查询城市编码对应的城市信息
* 基于 https://github.com/modood/Administrative-divisions-of-China 实现
*
* @author 芋道源码
*/
package cn.iocoder.yudao.framework.ip;

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,36 @@
package cn.iocoder.yudao.framework.ip.core.utils;
import cn.iocoder.yudao.framework.ip.core.Area;
import cn.iocoder.yudao.framework.ip.core.enums.AreaTypeEnum;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* {@link AreaUtils} 的单元测试
*
* @author 芋道源码
*/
public class AreaUtilsTest {
@Test
public void testGetArea() {
// 调用:北京
Area area = AreaUtils.getArea(110100);
// 断言
assertEquals(area.getId(), 110100);
assertEquals(area.getName(), "北京市");
assertEquals(area.getType(), AreaTypeEnum.CITY.getType());
assertEquals(area.getParent().getId(), 110000);
assertEquals(area.getChildren().size(), 16);
}
@Test
public void testFormat() {
assertEquals(AreaUtils.format(110105), "北京 北京市 朝阳区");
assertEquals(AreaUtils.format(1), "中国");
assertEquals(AreaUtils.format(2), "蒙古");
}
}

View File

@ -0,0 +1,47 @@
package cn.iocoder.yudao.framework.ip.core.utils;
import cn.iocoder.yudao.framework.ip.core.Area;
import org.junit.jupiter.api.Test;
import org.lionsoul.ip2region.xdb.Searcher;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* {@link IPUtils} 的单元测试
*
* @author wanglhup
*/
public class IPUtilsTest {
@Test
public void testGetAreaId_string() {
// 120.202.4.0|120.202.4.255|420600
Integer areaId = IPUtils.getAreaId("120.202.4.50");
assertEquals(420600, areaId);
}
@Test
public void testGetAreaId_long() throws Exception {
// 120.203.123.0|120.203.133.255|360900
long ip = Searcher.checkIP("120.203.123.250");
Integer areaId = IPUtils.getAreaId(ip);
assertEquals(360900, areaId);
}
@Test
public void testGetArea_string() {
// 120.202.4.0|120.202.4.255|420600
Area area = IPUtils.getArea("120.202.4.50");
assertEquals("襄阳市", area.getName());
}
@Test
public void testGetArea_long() throws Exception {
// 120.203.123.0|120.203.133.255|360900
long ip = Searcher.checkIP("120.203.123.252");
Area area = IPUtils.getArea(ip);
assertEquals("宜春市", area.getName());
}
}

View File

@ -0,0 +1,52 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-framework</artifactId>
<version>${revision}</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>yudao-spring-boot-starter-biz-operatelog</artifactId>
<packaging>jar</packaging>
<name>${project.artifactId}</name>
<description>操作日志</description>
<url>https://github.com/YunaiV/ruoyi-vue-pro</url>
<dependencies>
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-common</artifactId>
</dependency>
<!-- Spring 核心 -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Web 相关 -->
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-spring-boot-starter-web</artifactId>
<scope>provided</scope>
</dependency>
<!-- 业务组件 -->
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-module-system-api</artifactId>
<version>${revision}</version>
</dependency>
<!-- 工具类相关 -->
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,23 @@
package cn.iocoder.yudao.framework.operatelog.config;
import cn.iocoder.yudao.framework.operatelog.core.aop.OperateLogAspect;
import cn.iocoder.yudao.framework.operatelog.core.service.OperateLogFrameworkService;
import cn.iocoder.yudao.framework.operatelog.core.service.OperateLogFrameworkServiceImpl;
import cn.iocoder.yudao.module.system.api.logger.OperateLogApi;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.context.annotation.Bean;
@AutoConfiguration
public class YudaoOperateLogAutoConfiguration {
@Bean
public OperateLogAspect operateLogAspect() {
return new OperateLogAspect();
}
@Bean
public OperateLogFrameworkService operateLogFrameworkService(OperateLogApi operateLogApi) {
return new OperateLogFrameworkServiceImpl(operateLogApi);
}
}

View File

@ -0,0 +1,57 @@
package cn.iocoder.yudao.framework.operatelog.core.annotations;
import cn.iocoder.yudao.framework.operatelog.core.enums.OperateTypeEnum;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* 操作日志注解
*
* @author 芋道源码
*/
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface OperateLog {
// ========== 模块字段 ==========
/**
* 操作模块
*
* 为空时,会尝试读取 {@link Tag#name()} 属性
*/
String module() default "";
/**
* 操作名
*
* 为空时,会尝试读取 {@link Operation#summary()} 属性
*/
String name() default "";
/**
* 操作分类
*
* 实际并不是数组,因为枚举不能设置 null 作为默认值
*/
OperateTypeEnum[] type() default {};
// ========== 开关字段 ==========
/**
* 是否记录操作日志
*/
boolean enable() default true;
/**
* 是否记录方法参数
*/
boolean logArgs() default true;
/**
* 是否记录方法结果的数据
*/
boolean logResultData() default true;
}

View File

@ -0,0 +1,375 @@
package cn.iocoder.yudao.framework.operatelog.core.aop;
import cn.hutool.core.date.LocalDateTimeUtil;
import cn.hutool.core.exceptions.ExceptionUtil;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.enums.UserTypeEnum;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.common.util.monitor.TracerUtils;
import cn.iocoder.yudao.framework.common.util.servlet.ServletUtils;
import cn.iocoder.yudao.framework.operatelog.core.enums.OperateTypeEnum;
import cn.iocoder.yudao.framework.operatelog.core.service.OperateLog;
import cn.iocoder.yudao.framework.operatelog.core.service.OperateLogFrameworkService;
import cn.iocoder.yudao.framework.web.core.util.WebFrameworkUtils;
import com.google.common.collect.Maps;
import io.swagger.v3.oas.annotations.tags.Tag;
import io.swagger.v3.oas.annotations.Operation;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.validation.BindingResult;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.multipart.MultipartFile;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.annotation.Annotation;
import java.lang.reflect.Array;
import java.time.LocalDateTime;
import java.util.*;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import static cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR;
import static cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants.SUCCESS;
/**
* 拦截使用 @OperateLog 注解,如果满足条件,则生成操作日志。
* 满足如下任一条件,则会进行记录:
* 1. 使用 @ApiOperation + 非 @GetMapping
* 2. 使用 @OperateLog 注解
* <p>
* 但是,如果声明 @OperateLog 注解时,将 enable 属性设置为 false 时,强制不记录。
*
* @author 芋道源码
*/
@Aspect
@Slf4j
public class OperateLogAspect {
/**
* 用于记录操作内容的上下文
*
* @see OperateLog#getContent()
*/
private static final ThreadLocal<String> CONTENT = new ThreadLocal<>();
/**
* 用于记录拓展字段的上下文
*
* @see OperateLog#getExts()
*/
private static final ThreadLocal<Map<String, Object>> EXTS = new ThreadLocal<>();
@Resource
private OperateLogFrameworkService operateLogFrameworkService;
@Around("@annotation(operation)")
public Object around(ProceedingJoinPoint joinPoint, Operation operation) throws Throwable {
// 可能也添加了 @ApiOperation 注解
cn.iocoder.yudao.framework.operatelog.core.annotations.OperateLog operateLog = getMethodAnnotation(joinPoint,
cn.iocoder.yudao.framework.operatelog.core.annotations.OperateLog.class);
return around0(joinPoint, operateLog, operation);
}
@Around("!@annotation(io.swagger.v3.oas.annotations.Operation) && @annotation(operateLog)")
// 兼容处理,只添加 @OperateLog 注解的情况
public Object around(ProceedingJoinPoint joinPoint,
cn.iocoder.yudao.framework.operatelog.core.annotations.OperateLog operateLog) throws Throwable {
return around0(joinPoint, operateLog, null);
}
private Object around0(ProceedingJoinPoint joinPoint,
cn.iocoder.yudao.framework.operatelog.core.annotations.OperateLog operateLog,
Operation operation) throws Throwable {
// 目前,只有管理员,才记录操作日志!所以非管理员,直接调用,不进行记录
Integer userType = WebFrameworkUtils.getLoginUserType();
if (!Objects.equals(userType, UserTypeEnum.ADMIN.getValue())) {
return joinPoint.proceed();
}
// 记录开始时间
LocalDateTime startTime = LocalDateTime.now();
try {
// 执行原有方法
Object result = joinPoint.proceed();
// 记录正常执行时的操作日志
this.log(joinPoint, operateLog, operation, startTime, result, null);
return result;
} catch (Throwable exception) {
this.log(joinPoint, operateLog, operation, startTime, null, exception);
throw exception;
} finally {
clearThreadLocal();
}
}
public static void setContent(String content) {
CONTENT.set(content);
}
public static void addExt(String key, Object value) {
if (EXTS.get() == null) {
EXTS.set(new HashMap<>());
}
EXTS.get().put(key, value);
}
private static void clearThreadLocal() {
CONTENT.remove();
EXTS.remove();
}
private void log(ProceedingJoinPoint joinPoint,
cn.iocoder.yudao.framework.operatelog.core.annotations.OperateLog operateLog,
Operation operation,
LocalDateTime startTime, Object result, Throwable exception) {
try {
// 判断不记录的情况
if (!isLogEnable(joinPoint, operateLog)) {
return;
}
// 真正记录操作日志
this.log0(joinPoint, operateLog, operation, startTime, result, exception);
} catch (Throwable ex) {
log.error("[log][记录操作日志时,发生异常,其中参数是 joinPoint({}) operateLog({}) apiOperation({}) result({}) exception({}) ]",
joinPoint, operateLog, operation, result, exception, ex);
}
}
private void log0(ProceedingJoinPoint joinPoint,
cn.iocoder.yudao.framework.operatelog.core.annotations.OperateLog operateLog,
Operation operation,
LocalDateTime startTime, Object result, Throwable exception) {
OperateLog operateLogObj = new OperateLog();
// 补全通用字段
operateLogObj.setTraceId(TracerUtils.getTraceId());
operateLogObj.setStartTime(startTime);
// 补充用户信息
fillUserFields(operateLogObj);
// 补全模块信息
fillModuleFields(operateLogObj, joinPoint, operateLog, operation);
// 补全请求信息
fillRequestFields(operateLogObj);
// 补全方法信息
fillMethodFields(operateLogObj, joinPoint, operateLog, startTime, result, exception);
// 异步记录日志
operateLogFrameworkService.createOperateLog(operateLogObj);
}
private static void fillUserFields(OperateLog operateLogObj) {
operateLogObj.setUserId(WebFrameworkUtils.getLoginUserId());
operateLogObj.setUserType(WebFrameworkUtils.getLoginUserType());
}
private static void fillModuleFields(OperateLog operateLogObj,
ProceedingJoinPoint joinPoint,
cn.iocoder.yudao.framework.operatelog.core.annotations.OperateLog operateLog,
Operation operation) {
// module 属性
if (operateLog != null) {
operateLogObj.setModule(operateLog.module());
}
if (StrUtil.isEmpty(operateLogObj.getModule())) {
Tag tag = getClassAnnotation(joinPoint, Tag.class);
if (tag != null) {
// 优先读取 @Tag 的 name 属性
if (StrUtil.isNotEmpty(tag.name())) {
operateLogObj.setModule(tag.name());
}
// 没有的话,读取 @API 的 description 属性
if (StrUtil.isEmpty(operateLogObj.getModule()) && ArrayUtil.isNotEmpty(tag.description())) {
operateLogObj.setModule(tag.description());
}
}
}
// name 属性
if (operateLog != null) {
operateLogObj.setName(operateLog.name());
}
if (StrUtil.isEmpty(operateLogObj.getName()) && operation != null) {
operateLogObj.setName(operation.summary());
}
// type 属性
if (operateLog != null && ArrayUtil.isNotEmpty(operateLog.type())) {
operateLogObj.setType(operateLog.type()[0].getType());
}
if (operateLogObj.getType() == null) {
RequestMethod requestMethod = obtainFirstMatchRequestMethod(obtainRequestMethod(joinPoint));
OperateTypeEnum operateLogType = convertOperateLogType(requestMethod);
operateLogObj.setType(operateLogType != null ? operateLogType.getType() : null);
}
// content 和 exts 属性
operateLogObj.setContent(CONTENT.get());
operateLogObj.setExts(EXTS.get());
}
private static void fillRequestFields(OperateLog operateLogObj) {
// 获得 Request 对象
HttpServletRequest request = ServletUtils.getRequest();
if (request == null) {
return;
}
// 补全请求信息
operateLogObj.setRequestMethod(request.getMethod());
operateLogObj.setRequestUrl(request.getRequestURI());
operateLogObj.setUserIp(ServletUtils.getClientIP(request));
operateLogObj.setUserAgent(ServletUtils.getUserAgent(request));
}
private static void fillMethodFields(OperateLog operateLogObj,
ProceedingJoinPoint joinPoint,
cn.iocoder.yudao.framework.operatelog.core.annotations.OperateLog operateLog,
LocalDateTime startTime, Object result, Throwable exception) {
MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
operateLogObj.setJavaMethod(methodSignature.toString());
if (operateLog == null || operateLog.logArgs()) {
operateLogObj.setJavaMethodArgs(obtainMethodArgs(joinPoint));
}
if (operateLog == null || operateLog.logResultData()) {
operateLogObj.setResultData(obtainResultData(result));
}
operateLogObj.setDuration((int) (LocalDateTimeUtil.between(startTime, LocalDateTime.now()).toMillis()));
// (正常)处理 resultCode 和 resultMsg 字段
if (result instanceof CommonResult) {
CommonResult<?> commonResult = (CommonResult<?>) result;
operateLogObj.setResultCode(commonResult.getCode());
operateLogObj.setResultMsg(commonResult.getMsg());
} else {
operateLogObj.setResultCode(SUCCESS.getCode());
}
// (异常)处理 resultCode 和 resultMsg 字段
if (exception != null) {
operateLogObj.setResultCode(INTERNAL_SERVER_ERROR.getCode());
operateLogObj.setResultMsg(ExceptionUtil.getRootCauseMessage(exception));
}
}
private static boolean isLogEnable(ProceedingJoinPoint joinPoint,
cn.iocoder.yudao.framework.operatelog.core.annotations.OperateLog operateLog) {
// 有 @OperateLog 注解的情况下
if (operateLog != null) {
return operateLog.enable();
}
// 没有 @ApiOperation 注解的情况下,只记录 POST、PUT、DELETE 的情况
return obtainFirstLogRequestMethod(obtainRequestMethod(joinPoint)) != null;
}
private static RequestMethod obtainFirstLogRequestMethod(RequestMethod[] requestMethods) {
if (ArrayUtil.isEmpty(requestMethods)) {
return null;
}
return Arrays.stream(requestMethods).filter(requestMethod ->
requestMethod == RequestMethod.POST
|| requestMethod == RequestMethod.PUT
|| requestMethod == RequestMethod.DELETE)
.findFirst().orElse(null);
}
private static RequestMethod obtainFirstMatchRequestMethod(RequestMethod[] requestMethods) {
if (ArrayUtil.isEmpty(requestMethods)) {
return null;
}
// 优先,匹配最优的 POST、PUT、DELETE
RequestMethod result = obtainFirstLogRequestMethod(requestMethods);
if (result != null) {
return result;
}
// 然后,匹配次优的 GET
result = Arrays.stream(requestMethods).filter(requestMethod -> requestMethod == RequestMethod.GET)
.findFirst().orElse(null);
if (result != null) {
return result;
}
// 兜底,获得第一个
return requestMethods[0];
}
private static OperateTypeEnum convertOperateLogType(RequestMethod requestMethod) {
if (requestMethod == null) {
return null;
}
switch (requestMethod) {
case GET:
return OperateTypeEnum.GET;
case POST:
return OperateTypeEnum.CREATE;
case PUT:
return OperateTypeEnum.UPDATE;
case DELETE:
return OperateTypeEnum.DELETE;
default:
return OperateTypeEnum.OTHER;
}
}
private static RequestMethod[] obtainRequestMethod(ProceedingJoinPoint joinPoint) {
RequestMapping requestMapping = AnnotationUtils.getAnnotation( // 使用 Spring 的工具类,可以处理 @RequestMapping 别名注解
((MethodSignature) joinPoint.getSignature()).getMethod(), RequestMapping.class);
return requestMapping != null ? requestMapping.method() : new RequestMethod[]{};
}
@SuppressWarnings("SameParameterValue")
private static <T extends Annotation> T getMethodAnnotation(ProceedingJoinPoint joinPoint, Class<T> annotationClass) {
return ((MethodSignature) joinPoint.getSignature()).getMethod().getAnnotation(annotationClass);
}
@SuppressWarnings("SameParameterValue")
private static <T extends Annotation> T getClassAnnotation(ProceedingJoinPoint joinPoint, Class<T> annotationClass) {
return ((MethodSignature) joinPoint.getSignature()).getMethod().getDeclaringClass().getAnnotation(annotationClass);
}
private static String obtainMethodArgs(ProceedingJoinPoint joinPoint) {
// TODO 提升:参数脱敏和忽略
MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
String[] argNames = methodSignature.getParameterNames();
Object[] argValues = joinPoint.getArgs();
// 拼接参数
Map<String, Object> args = Maps.newHashMapWithExpectedSize(argValues.length);
for (int i = 0; i < argNames.length; i++) {
String argName = argNames[i];
Object argValue = argValues[i];
// 被忽略时,标记为 ignore 字符串,避免和 null 混在一起
args.put(argName, !isIgnoreArgs(argValue) ? argValue : "[ignore]");
}
return JsonUtils.toJsonString(args);
}
private static String obtainResultData(Object result) {
// TODO 提升:结果脱敏和忽略
if (result instanceof CommonResult) {
result = ((CommonResult<?>) result).getData();
}
return JsonUtils.toJsonString(result);
}
private static boolean isIgnoreArgs(Object object) {
Class<?> clazz = object.getClass();
// 处理数组的情况
if (clazz.isArray()) {
return IntStream.range(0, Array.getLength(object))
.anyMatch(index -> isIgnoreArgs(Array.get(object, index)));
}
// 递归处理数组、Collection、Map 的情况
if (Collection.class.isAssignableFrom(clazz)) {
return ((Collection<?>) object).stream()
.anyMatch((Predicate<Object>) OperateLogAspect::isIgnoreArgs);
}
if (Map.class.isAssignableFrom(clazz)) {
return isIgnoreArgs(((Map<?, ?>) object).values());
}
// obj
return object instanceof MultipartFile
|| object instanceof HttpServletRequest
|| object instanceof HttpServletResponse
|| object instanceof BindingResult;
}
}

View File

@ -0,0 +1,55 @@
package cn.iocoder.yudao.framework.operatelog.core.enums;
import cn.iocoder.yudao.framework.operatelog.core.annotations.OperateLog;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 操作日志的操作类型
*
* @author ruoyi
*/
@Getter
@AllArgsConstructor
public enum OperateTypeEnum {
/**
* 查询
*
* 绝大多数情况下,不会记录查询动作,因为过于大量显得没有意义。
* 在有需要的时候,通过声明 {@link OperateLog} 注解来记录
*/
GET(1),
/**
* 新增
*/
CREATE(2),
/**
* 修改
*/
UPDATE(3),
/**
* 删除
*/
DELETE(4),
/**
* 导出
*/
EXPORT(5),
/**
* 导入
*/
IMPORT(6),
/**
* 其它
*
* 在无法归类时,可以选择使用其它。因为还有操作名可以进一步标识
*/
OTHER(0);
/**
* 类型
*/
private final Integer type;
}

View File

@ -0,0 +1 @@
package cn.iocoder.yudao.framework.operatelog.core;

View File

@ -0,0 +1,110 @@
package cn.iocoder.yudao.framework.operatelog.core.service;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.Map;
/**
* 操作日志
*
* @author 芋道源码
*/
@Data
public class OperateLog {
/**
* 链路追踪编号
*/
private String traceId;
/**
* 用户编号
*/
private Long userId;
/**
* 用户类型
*/
private Integer userType;
/**
* 操作模块
*/
private String module;
/**
* 操作名
*/
private String name;
/**
* 操作分类
*/
private Integer type;
/**
* 操作明细
*/
private String content;
/**
* 拓展字段
*/
private Map<String, Object> exts;
/**
* 请求方法名
*/
private String requestMethod;
/**
* 请求地址
*/
private String requestUrl;
/**
* 用户 IP
*/
private String userIp;
/**
* 浏览器 UserAgent
*/
private String userAgent;
/**
* Java 方法名
*/
private String javaMethod;
/**
* Java 方法的参数
*/
private String javaMethodArgs;
/**
* 开始时间
*/
private LocalDateTime startTime;
/**
* 执行时长,单位:毫秒
*/
private Integer duration;
/**
* 结果码
*/
private Integer resultCode;
/**
* 结果提示
*/
private String resultMsg;
/**
* 结果数据
*/
private String resultData;
}

View File

@ -0,0 +1,17 @@
package cn.iocoder.yudao.framework.operatelog.core.service;
/**
* 操作日志 Framework Service 接口
*
* @author 芋道源码
*/
public interface OperateLogFrameworkService {
/**
* 记录操作日志
*
* @param operateLog 操作日志请求
*/
void createOperateLog(OperateLog operateLog);
}

View File

@ -0,0 +1,28 @@
package cn.iocoder.yudao.framework.operatelog.core.service;
import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.module.system.api.logger.OperateLogApi;
import cn.iocoder.yudao.module.system.api.logger.dto.OperateLogCreateReqDTO;
import lombok.RequiredArgsConstructor;
import org.springframework.scheduling.annotation.Async;
/**
* 操作日志 Framework Service 实现类
*
* 基于 {@link OperateLogApi} 实现,记录操作日志
*
* @author 芋道源码
*/
@RequiredArgsConstructor
public class OperateLogFrameworkServiceImpl implements OperateLogFrameworkService {
private final OperateLogApi operateLogApi;
@Override
@Async
public void createOperateLog(OperateLog operateLog) {
OperateLogCreateReqDTO reqDTO = BeanUtil.toBean(operateLog, OperateLogCreateReqDTO.class);
operateLogApi.createOperateLog(reqDTO);
}
}

View File

@ -0,0 +1,21 @@
package cn.iocoder.yudao.framework.operatelog.core.util;
import cn.iocoder.yudao.framework.operatelog.core.aop.OperateLogAspect;
/**
* 操作日志工具类
* 目前主要的作用,是提供给业务代码,记录操作明细和拓展字段
*
* @author 芋道源码
*/
public class OperateLogUtils {
public static void setContent(String content) {
OperateLogAspect.setContent(content);
}
public static void addExt(String key, Object value) {
OperateLogAspect.addExt(key, value);
}
}

View File

@ -0,0 +1,6 @@
/**
* 用户操作日志:记录用户的操作,用于对用户的操作的审计与追溯,永久保存。
*
* @author 芋道源码
*/
package cn.iocoder.yudao.framework.operatelog;

View File

@ -0,0 +1 @@
cn.iocoder.yudao.framework.operatelog.config.YudaoOperateLogAutoConfiguration

Some files were not shown because too many files have changed in this diff Show More