实现接口幂等
接口幂等
文章目录
1.基于 Token
基于 Token 这种方案的实现思路很简单,整个流程分两步:
- 客户端发送请求,从服务端获取一个 Token 令牌,每次请求获取到的都是一个全新的令牌。
- 客户端发送请求的时候,携带上第一步的令牌,处理请求之前,先校验令牌是否存在,当请求处理成功,就把令牌删除掉。
2. 基于请求参数的校验
拦截器方式
需要引入redis
使用一个拦截器实现,拦截需要接口幂等的接口方法;
判断是否是重复提交的数据
以「固定前缀 + 请求 URL + 用户身份认证 token(请求头 Authorization)」作为 Redis 的 key,value 中保存本次请求参数和提交时间;同一个 key 下,如果参数相同且两次提交的间隔小于注解设置的时间,则判定为重复提交
1.搭建一个springboot 项目
pom.xml依赖:
web , redis , commons-lang3
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- redis依赖-->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
<version>2.6.4</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
</dependencies>
配置文件
application.yml
server:
port: 8001
spring:
redis:
# 地址
host: 127.0.0.1
# 端口,默认为6379
port: 6379
# 密码
# 连接超时时间
timeout: 10s
lettuce:
pool:
# 连接池中的最小空闲连接
min-idle: 0
# 连接池中的最大空闲连接
max-idle: 8
# 连接池的最大数据库连接数
max-active: 8
# #连接池最大阻塞等待时间(使用负值表示没有限制)
max-wait: -1ms
配置redis配置类
RedisTemplateConfig
/**
 * Spring configuration that registers a {@code RedisTemplate<String, Object>} bean.
 *
 * <p>Keys and hash keys are written as plain strings; values and hash values are
 * written with JDK serialization.
 */
@Configuration
@EnableCaching
public class RedisTemplateConfig extends CachingConfigurerSupport {

    /**
     * Creates and initializes the Redis template.
     *
     * @param redisConnectionFactory connection factory the template talks to Redis through
     * @return the initialized template
     */
    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        final RedisTemplate<String, Object> template = new RedisTemplate<>();
        template.setConnectionFactory(redisConnectionFactory);

        // One string serializer instance is shared by keys and hash keys.
        final StringRedisSerializer keySerializer = new StringRedisSerializer();
        template.setKeySerializer(keySerializer);
        template.setHashKeySerializer(keySerializer);

        // Values and hash values each get their own JDK serializer instance.
        template.setValueSerializer(new JdkSerializationRedisSerializer());
        template.setHashValueSerializer(new JdkSerializationRedisSerializer());

        template.afterPropertiesSet();
        return template;
    }
}
封装redis的缓存工具类
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;
import org.springframework.stereotype.Component;
import java.util.concurrent.TimeUnit;
/**
 * Small helper around {@link RedisTemplate} for writing and reading single values.
 *
 * @author: wenyi
 * @create: 2022/12/8
 */
@Component
public class RedisCache {

    /** Template injected by Spring. It is declared as a raw type, so value types are unchecked. */
    @Autowired
    public RedisTemplate redisTemplate;

    /**
     * Stores a value under the given key with a time-to-live.
     *
     * @param key      cache key
     * @param value    value to store
     * @param timeout  time-to-live amount (unboxed when passed on, so it must not be null)
     * @param timeUnit unit of {@code timeout}
     */
    public <T> void setCacheObject(final String key, final T value, final Integer timeout, final TimeUnit timeUnit) {
        ValueOperations valueOps = redisTemplate.opsForValue();
        valueOps.set(key, value, timeout, timeUnit);
    }

    /**
     * Reads the value stored under the given key.
     *
     * @param key cache key
     * @return whatever the underlying value operations return for {@code key}
     */
    public <T> T getCacheObject(final String key) {
        ValueOperations<String, T> valueOps = redisTemplate.opsForValue();
        T cached = valueOps.get(key);
        return cached;
    }
}
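HttpServletRequest 的请求体(输入流)只能读取一次,读取之后后续(例如 @RequestBody 参数解析)就无法再读取
因此需要构建一个可重复读取 inputStream 的 request 包装类(注意:还需要通过一个 Filter 把原始请求包装成该类后再放行,否则拦截器中拿到的请求不会是 RepeatedlyRequestWrapper)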
/**
 * Request wrapper that reads the whole request body once at construction time and
 * serves it from memory afterwards, so that {@link #getInputStream()} and
 * {@link #getReader()} can be called any number of times.
 */
public class RepeatedlyRequestWrapper extends HttpServletRequestWrapper {

    /** Request body captured in the constructor, encoded as UTF-8. */
    private final byte[] body;

    /**
     * Wraps the request and buffers its body.
     *
     * <p>Both the request and the response character encoding are set to UTF-8.
     *
     * @param request  request whose body is consumed and buffered
     * @param response response whose character encoding is set to UTF-8
     * @throws IOException if the encoding cannot be applied or the body cannot be encoded
     */
    public RepeatedlyRequestWrapper(HttpServletRequest request, ServletResponse response) throws IOException {
        super(request);
        request.setCharacterEncoding(Constants.UTF8);
        response.setCharacterEncoding(Constants.UTF8);
        body = HttpHelper.getBodyString(request).getBytes(Constants.UTF8);
    }

    /**
     * Returns a new reader over the buffered body.
     *
     * @return reader that decodes the buffered bytes as UTF-8
     * @throws IOException if the reader cannot be created
     */
    @Override
    public BufferedReader getReader() throws IOException {
        // Decode with the same charset the body was encoded with, not the platform default.
        return new BufferedReader(new InputStreamReader(getInputStream(), Constants.UTF8));
    }

    /**
     * Returns a new stream positioned at the start of the buffered body.
     *
     * @return an independent stream over the buffered bytes
     * @throws IOException declared by the overridden method
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream buffer = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return buffer.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                // Bulk read straight from the buffer instead of one byte at a time.
                return buffer.read(b, off, len);
            }

            @Override
            public int available() throws IOException {
                // Remaining bytes, not the total body length.
                return buffer.available();
            }

            @Override
            public boolean isFinished() {
                // Finished once every buffered byte has been consumed.
                return buffer.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Intentionally a no-op: the listener is ignored.
            }
        };
    }
}
通用常量信息
/**
 * Shared constant values.
 */
public class Constants {

    /** Name of the UTF-8 character set. */
    public static final String UTF8 = "UTF-8";
}
通用http工具封装
/**
 * HTTP helper utilities.
 */
public class HttpHelper {
    private static final Logger LOGGER = LoggerFactory.getLogger(HttpHelper.class);

    /**
     * Reads the request body as a UTF-8 string.
     *
     * <p>The body is read line by line and the lines are concatenated without their
     * line terminators, so a multi-line body comes back as a single line.
     *
     * @param request request whose input stream is consumed
     * @return the text read so far; empty or partial if reading fails
     */
    public static String getBodyString(ServletRequest request) {
        StringBuilder sb = new StringBuilder();
        // try-with-resources closes both the reader and the stream, even on failure.
        try (InputStream inputStream = request.getInputStream();
             BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                sb.append(line);
            }
        } catch (IOException e) {
            // Best effort: keep whatever was read, but log the cause instead of dropping it.
            LOGGER.warn("getBodyString出现问题!", e);
        }
        return sb.toString();
    }
}
2.定义接口幂等注解
自定义一个注解,在需要进行幂等性处理的接口上,添加该注解即可,将来这个接口就会自动的进行幂等性处理。
import java.lang.annotation.*;
/**
 * Marks a handler method for duplicate-submission (idempotency) checking.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RepeatSubmit {

    /**
     * Minimum interval between two submissions, in milliseconds. A second submission
     * arriving within this interval is treated as a duplicate.
     */
    int interval() default 5000;

    /**
     * Message returned to the caller when a duplicate submission is rejected.
     */
    String message() default "不允许重复提交,请稍候再试";
}
3.实现拦截器
定义抽象拦截器父类
将接口方法拦截下来,然后找到接口方法上的 @RepeatSubmit 注解,调用 isRepeatSubmit 方法去判断是否是重复提交的数据
import com.fasterxml.jackson.databind.ObjectMapper;
import com.ung.myUrl.annotation.RepeatSubmit;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
/**
 * Base interceptor for duplicate-submission checks. Subclasses supply the actual
 * duplicate-detection rule.
 *
 * @author: wenyi
 * @create: 2022/12/8
 */
public abstract class RepeatSubmitInterceptor implements HandlerInterceptor {

    /**
     * Rejects the request when the target handler method carries {@link RepeatSubmit}
     * and {@link #isRepeatSubmit} reports a duplicate.
     *
     * <p>On rejection a JSON object with {@code status} 500 and {@code msg} set to the
     * annotation's message is written to the response.
     *
     * @param request  current request
     * @param response current response
     * @param handler  chosen handler; anything other than a {@link HandlerMethod} is let through
     * @return {@code false} if the request was rejected as a duplicate, {@code true} otherwise
     * @throws Exception if writing the rejection response fails
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if (!(handler instanceof HandlerMethod)) {
            return true;
        }

        Method target = ((HandlerMethod) handler).getMethod();
        RepeatSubmit repeatSubmit = target.getAnnotation(RepeatSubmit.class);
        if (repeatSubmit == null || !this.isRepeatSubmit(request, repeatSubmit)) {
            return true;
        }

        Map<String, Object> payload = new HashMap<>();
        payload.put("status", 500);
        payload.put("msg", repeatSubmit.message());
        response.setContentType("application/json;charset=utf-8");
        response.getWriter().write(new ObjectMapper().writeValueAsString(payload));
        return false;
    }

    /**
     * Decides whether the request is a duplicate submission.
     *
     * @param request    current request
     * @param annotation the {@link RepeatSubmit} found on the handler method
     * @return {@code true} if the request must be rejected as a duplicate
     */
    public abstract boolean isRepeatSubmit(HttpServletRequest request, RepeatSubmit annotation);
}
子类实现父类的具体判断幂等规则
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.ung.myUrl.annotation.RepeatSubmit;
import com.ung.myUrl.filter.RepeatedlyRequestWrapper;
import com.ung.myUrl.utils.RedisCache;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
 * Duplicate-submission rule based on URL, identity header and request data.
 *
 * <p>A request is a duplicate when a record cached under the same key holds the same
 * request data and was stored less than {@link RepeatSubmit#interval()} milliseconds ago.
 *
 * @author: wenyi
 * @create: 2022/12/8
 */
@Component
public class SameUrlDataInterceptor extends RepeatSubmitInterceptor {
    /** Key of the request data inside the cached record. */
    public final String REPEAT_PARAMS = "repeatParams";
    /** Key of the submission timestamp (epoch millis) inside the cached record. */
    public final String REPEAT_TIME = "repeatTime";
    /** Prefix of every cache key written by this interceptor. */
    public final static String REPEAT_SUBMIT_KEY = "REPEAT_SUBMIT_KEY";
    /** Name of the request header whose value identifies the caller. */
    private String header = "Authorization";
    @Autowired
    private RedisCache redisCache;

    /**
     * Checks the request against the record cached for the same URL and caller.
     *
     * <p>When the request is not a duplicate, the current data and timestamp are cached
     * with a time-to-live of {@code annotation.interval()} milliseconds.
     *
     * <p>NOTE(review): the read and the write are two separate cache calls, so two
     * concurrent identical requests may both be accepted.
     *
     * @param request    current request
     * @param annotation annotation supplying the interval
     * @return {@code true} if the request repeats the cached one within the interval
     */
    @SuppressWarnings("unchecked")
    @Override
    public boolean isRepeatSubmit(HttpServletRequest request, RepeatSubmit annotation) {
        String nowParams = readBody(request);
        // No body available: fall back to the request parameters.
        if (StringUtils.isEmpty(nowParams)) {
            nowParams = serializeParameters(request);
        }

        Map<String, Object> nowDataMap = new HashMap<String, Object>();
        nowDataMap.put(REPEAT_PARAMS, nowParams);
        nowDataMap.put(REPEAT_TIME, System.currentTimeMillis());

        // Cache key = fixed prefix + request URI + identity header.
        String url = request.getRequestURI();
        // A missing header contributes nothing, so the key is just prefix + URI
        // rather than containing the literal text "null".
        String submitKey = StringUtils.defaultString(request.getHeader(header));
        String cacheRepeatKey = REPEAT_SUBMIT_KEY + url + submitKey;

        Object sessionObj = redisCache.getCacheObject(cacheRepeatKey);
        if (sessionObj instanceof Map) {
            Map<String, Object> sessionMap = (Map<String, Object>) sessionObj;
            if (compareParams(nowDataMap, sessionMap) && compareTime(nowDataMap, sessionMap, annotation.interval())) {
                return true;
            }
        }
        redisCache.setCacheObject(cacheRepeatKey, nowDataMap, annotation.interval(), TimeUnit.MILLISECONDS);
        return false;
    }

    /**
     * Reads the complete body of a {@link RepeatedlyRequestWrapper}.
     *
     * @return all body lines concatenated; empty for other request types or on read failure
     */
    private String readBody(HttpServletRequest request) {
        if (!(request instanceof RepeatedlyRequestWrapper)) {
            return "";
        }
        StringBuilder body = new StringBuilder();
        // Read every line, not only the first, and close the reader afterwards.
        try (java.io.BufferedReader reader = request.getReader()) {
            String line;
            while ((line = reader.readLine()) != null) {
                body.append(line);
            }
        } catch (IOException e) {
            // Best effort, as before: report and fall back to the request parameters.
            // No logger is declared in this class, so the stack trace is still printed.
            e.printStackTrace();
            return "";
        }
        return body.toString();
    }

    /**
     * Serializes the request parameter map to JSON.
     *
     * @return the JSON text, or an empty string if serialization fails
     */
    private String serializeParameters(HttpServletRequest request) {
        try {
            return new ObjectMapper().writeValueAsString(request.getParameterMap());
        } catch (JsonProcessingException e) {
            e.printStackTrace();
            return "";
        }
    }

    /**
     * Tells whether both records hold the same request data.
     */
    private boolean compareParams(Map<String, Object> nowMap, Map<String, Object> preMap) {
        Object nowParams = nowMap.get(REPEAT_PARAMS);
        Object preParams = preMap.get(REPEAT_PARAMS);
        if (!(nowParams instanceof String) || !(preParams instanceof String)) {
            // Missing or unexpected data can never count as "same".
            return false;
        }
        return StringUtils.equals((String) nowParams, (String) preParams);
    }

    /**
     * Tells whether the two submissions are less than {@code interval} milliseconds apart.
     */
    private boolean compareTime(Map<String, Object> nowMap, Map<String, Object> preMap, int interval) {
        Object now = nowMap.get(REPEAT_TIME);
        Object previous = preMap.get(REPEAT_TIME);
        if (!(now instanceof Long) || !(previous instanceof Long)) {
            // No usable timestamp: do not treat the request as a duplicate.
            return false;
        }
        return ((Long) now - (Long) previous) < interval;
    }
}
最后配置拦截器
import com.ung.myUrl.interceptor.RepeatSubmitInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
 * Web MVC configuration that registers the duplicate-submission interceptor.
 *
 * @author: wenyi
 * @create: 2022/12/8
 */
@Configuration
public class WebConfig implements WebMvcConfigurer {

    /** Interceptor injected by Spring. */
    @Autowired
    RepeatSubmitInterceptor repeatSubmitInterceptor;

    /**
     * Applies the interceptor to every request path.
     *
     * @param registry registry the interceptor is added to
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry
                .addInterceptor(repeatSubmitInterceptor)
                .addPathPatterns("/**");
    }
}
4.使用幂等注解
最后在controller使用
/**
 * Demo controller showing {@code @RepeatSubmit} on a POST endpoint.
 */
@RestController
public class HelloController {

    /**
     * Prints the request body and answers with a fixed string. Identical submissions
     * within 100000 ms are subject to the duplicate-submission check.
     *
     * @param msg raw request body
     * @return the literal {@code "hello"}
     */
    @RepeatSubmit(interval = 100000)
    @PostMapping("/hello")
    public String hello(@RequestBody String msg) {
        final String line = "msg = " + msg;
        System.out.println(line);
        return "hello";
    }
}
5.测试使用
post请求 http://localhost:8001/hello
设置token ,这里不需要token认证,随便一个就行
最后请求成功,可以看到redis里面有数据,就是按照 指定key + URL + token
连续再次点击,会提示 不允许重复提交,说明幂等设置成功
AOP方式
实现防刷切面PreventAop.java
大致逻辑为:定义一个切面,以 @Prevent 注解作为切入点,在该切面的前置通知中获取客户端 IP 和方法入参(入参序列化为 JSON);将「完整方法名 + Base64(IP + 入参)」作为 redis 的 key,IP + 入参作为 redis 的 value,@Prevent 的 value 作为 redis 的过期时间(秒),存入 redis;每次进入这个切面时根据该 key 判断 redis 中的值是否存在,存在则拦截防刷,不存在则允许调用;
pom依赖
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.6.11</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.ung</groupId>
<artifactId>controller-test</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>controller-test</name>
<description>Demo project for Spring Boot</description>
<packaging>jar</packaging>
<properties>
<java.version>1.8</java.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<!--json依赖-->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.68</version>
</dependency>
<!--校验注解-->
<dependency>
<groupId>org.hibernate.validator</groupId>
<artifactId>hibernate-validator</artifactId>
<version>6.0.7.Final</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- redis依赖-->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
<version>2.6.4</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
配置文件 配置redis
server:
port: 8001
spring:
redis:
# 地址
host: 127.0.0.1
# 端口,默认为6379
port: 6379
# 密码
# 连接超时时间
timeout: 10s
lettuce:
pool:
# 连接池中的最小空闲连接
min-idle: 0
# 连接池中的最大空闲连接
max-idle: 8
# 连接池的最大数据库连接数
max-active: 8
# #连接池最大阻塞等待时间(使用负值表示没有限制)
max-wait: -1ms
Redis 配置
package com.ung.controllertest.config;
import org.springframework.cache.annotation.CachingConfigurerSupport;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.JdkSerializationRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
/**
 * Registers the {@code RedisTemplate<String, Object>} bean used by the application.
 *
 * <p>String serialization is used for keys and hash keys; JDK serialization is used
 * for values and hash values.
 */
@Configuration
@EnableCaching
public class RedisTemplateConfig extends CachingConfigurerSupport {

    /**
     * Builds the Redis template on top of the given connection factory.
     *
     * @param redisConnectionFactory factory providing the Redis connections
     * @return the fully initialized template
     */
    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        final RedisTemplate<String, Object> result = new RedisTemplate<>();
        result.setConnectionFactory(redisConnectionFactory);

        // Keys (plain and hash) share one string serializer.
        final StringRedisSerializer stringSerializer = new StringRedisSerializer();
        result.setKeySerializer(stringSerializer);
        result.setHashKeySerializer(stringSerializer);

        // Values (plain and hash) are JDK-serialized, each with its own serializer instance.
        final JdkSerializationRedisSerializer valueSerializer = new JdkSerializationRedisSerializer();
        final JdkSerializationRedisSerializer hashValueSerializer = new JdkSerializationRedisSerializer();
        result.setValueSerializer(valueSerializer);
        result.setHashValueSerializer(hashValueSerializer);

        result.afterPropertiesSet();
        return result;
    }
}
定义 RedisUtil 操作的工具类
package com.ung.controllertest.utils;
import com.alibaba.fastjson.JSON;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.*;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.util.concurrent.TimeUnit;
/**
 * Redis helper for storing values as strings (JSON for non-scalar objects).
 */
@Component
public class RedisUtil {
    @Autowired
    private RedisTemplate redisTemplate;
    @Resource(name="redisTemplate")
    private ValueOperations<String, String> valueOperations;
    @Resource(name="redisTemplate")
    private HashOperations<String, String, Object> hashOperations;
    @Resource(name="redisTemplate")
    private ListOperations<String, Object> listOperations;
    @Resource(name="redisTemplate")
    private SetOperations<String, Object> setOperations;
    @Resource(name="redisTemplate")
    private ZSetOperations<String, Object> zSetOperations;
    /** Default time-to-live, in seconds (24 hours). */
    public final static long DEFAULT_EXPIRE = 60 * 60 * 24;
    /** Marker meaning "do not set a time-to-live". */
    public final static long NOT_EXPIRE = -1;

    /**
     * Stores a value, optionally with a time-to-live.
     *
     * @param key    cache key
     * @param value  value to store; converted with {@link #toJson(Object)}
     * @param expire time-to-live in seconds, or {@link #NOT_EXPIRE} for none
     */
    public void set(String key, Object value, long expire){
        String json = toJson(value);
        if (expire > 0) {
            // Write value and TTL in one call, so a failure between two separate
            // calls can no longer leave a key behind that never expires.
            valueOperations.set(key, json, expire, TimeUnit.SECONDS);
            return;
        }
        // NOT_EXPIRE and other non-positive values keep the previous two-step behavior.
        valueOperations.set(key, json);
        if(expire != NOT_EXPIRE){
            redisTemplate.expire(key, expire, TimeUnit.SECONDS);
        }
    }

    /**
     * Stores a value with the {@link #DEFAULT_EXPIRE} time-to-live.
     */
    public void set(String key, Object value){
        set(key, value, DEFAULT_EXPIRE);
    }

    /**
     * Reads a value and parses it as JSON into {@code clazz}.
     *
     * @param key    cache key
     * @param clazz  target type
     * @param expire new time-to-live in seconds applied to the key, or {@link #NOT_EXPIRE} to leave it untouched
     * @return the parsed value, or {@code null} if nothing is stored
     */
    public <T> T get(String key, Class<T> clazz, long expire) {
        String value = valueOperations.get(key);
        if(expire != NOT_EXPIRE){
            redisTemplate.expire(key, expire, TimeUnit.SECONDS);
        }
        return value == null ? null : fromJson(value, clazz);
    }

    /**
     * Reads a value and parses it into {@code clazz} without touching the time-to-live.
     */
    public <T> T get(String key, Class<T> clazz) {
        return get(key, clazz, NOT_EXPIRE);
    }

    /**
     * Reads the stored string.
     *
     * @param key    cache key
     * @param expire new time-to-live in seconds applied to the key, or {@link #NOT_EXPIRE} to leave it untouched
     * @return the stored string, or {@code null} if nothing is stored
     */
    public String get(String key, long expire) {
        String value = valueOperations.get(key);
        if(expire != NOT_EXPIRE){
            redisTemplate.expire(key, expire, TimeUnit.SECONDS);
        }
        return value;
    }

    /**
     * Reads the stored string without touching the time-to-live.
     */
    public String get(String key) {
        return get(key, NOT_EXPIRE);
    }

    /**
     * Deletes the key.
     */
    public void delete(String key) {
        redisTemplate.delete(key);
    }

    /**
     * Converts an object to the string that is stored: scalars and strings via
     * {@code String.valueOf}, everything else as JSON.
     */
    private String toJson(Object object){
        if(object instanceof Integer || object instanceof Long || object instanceof Float ||
                object instanceof Double || object instanceof Boolean || object instanceof String){
            return String.valueOf(object);
        }
        return JSON.toJSONString(object);
    }

    /**
     * Parses JSON text into an instance of {@code clazz}.
     */
    private <T> T fromJson(String json, Class<T> clazz){
        return JSON.parseObject(json, clazz);
    }
}
定义注解Prevent
package com.ung.controllertest.annotation;
import com.ung.controllertest.enums.PreventStrategyEnum;
import java.lang.annotation.*;
/**
 * Marks a method as protected against repeated calls.
 *
 * <p>Usage: put this annotation on the method that needs the protection.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Documented
public @interface Prevent {

    /**
     * Length of the restriction window, in seconds, given as a decimal string.
     *
     * @return the window length
     */
    String value() default "60";

    /**
     * Message reported when a call is rejected.
     */
    String message() default "";

    /**
     * Strategy used to handle the call.
     *
     * @return the strategy
     */
    PreventStrategyEnum strategy() default PreventStrategyEnum.DEFAULT;
}
定义枚举 设置策略
package com.ung.controllertest.enums;
/**
 * Strategies available to the {@code @Prevent} repeated-call protection.
 */
public enum PreventStrategyEnum {

    /** Default strategy (no repeated request allowed within one minute). */
    DEFAULT
}
实现防刷切面PreventAop
package com.ung.controllertest.aop;
import com.alibaba.fastjson.JSON;
import com.ung.controllertest.annotation.Prevent;
import com.ung.controllertest.enums.PreventStrategyEnum;
import com.ung.controllertest.exception.BusinessException;
import com.ung.controllertest.utils.RedisUtil;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.Base64;
/**
 * Aspect implementing the {@link Prevent} repeated-call protection.
 *
 * <p>Before an annotated method runs, the client IP and the JSON form of the method
 * arguments are combined. "Fully-qualified method name + Base64(IP + arguments)" is
 * the Redis key, "IP + arguments" is the value, and the annotation's {@code value} is
 * the time-to-live in seconds. While the key exists, further identical calls are rejected.
 */
@Aspect
@Component
@Slf4j
public class PreventAop {
    /** Request headers consulted for the client address, in order of preference. */
    private static final String[] IP_HEADERS = {
            "x-forwarded-for",
            "Proxy-Client-IP",
            "WL-Proxy-Client-IP",
            "HTTP_CLIENT_IP",
            "HTTP_X_FORWARDED_FOR"
    };

    @Autowired
    private RedisUtil redisUtil;

    /**
     * Pointcut matching methods annotated with {@link Prevent}.
     */
    @Pointcut("@annotation(com.ung.controllertest.annotation.Prevent)")
    public void pointcut() {
    }

    /**
     * Runs the repeated-call check before the annotated method.
     *
     * @param joinPoint the intercepted call; one of its arguments must be the {@link HttpServletRequest}
     * @throws Exception {@code BusinessException} when the call is rejected or the arguments are empty
     */
    @Before("pointcut()")
    public void joinPoint(JoinPoint joinPoint) throws Exception {
        Object[] args = joinPoint.getArgs();

        // Client IP comes from the request passed to the annotated method.
        HttpServletRequest request = findRequest(args);
        String ip = getIp(request);
        log.debug("请求的:{}", ip);

        // Serialize every argument other than the request itself. A StringBuilder
        // avoids the literal "null" prefix that `null += ...` used to produce.
        StringBuilder params = new StringBuilder();
        boolean hasParams = false;
        for (Object arg : args) {
            if (arg instanceof HttpServletRequest) {
                continue;
            }
            hasParams = true;
            params.append(JSON.toJSONString(arg));
        }
        String requestStr = params.toString();
        // Only methods that declare extra parameters are required to supply them.
        if (hasParams && (requestStr.isEmpty() || requestStr.equalsIgnoreCase("{}"))) {
            throw new BusinessException("[防刷]入参不允许为空");
        }

        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = joinPoint.getTarget().getClass().getMethod(methodSignature.getName(),
                methodSignature.getParameterTypes());
        Prevent preventAnnotation = method.getAnnotation(Prevent.class);
        String methodFullName = method.getDeclaringClass().getName() + method.getName();
        entrance(preventAnnotation, ip + requestStr, methodFullName);
    }

    /**
     * Finds the servlet request among the method arguments.
     *
     * @throws IllegalStateException if the annotated method has no request parameter
     */
    private HttpServletRequest findRequest(Object[] args) {
        for (Object arg : args) {
            if (arg instanceof HttpServletRequest) {
                return (HttpServletRequest) arg;
            }
        }
        throw new IllegalStateException("@Prevent requires the method to declare an HttpServletRequest parameter");
    }

    /**
     * Determines the client address from the forwarding headers, falling back to the
     * remote address of the connection.
     *
     * <p>NOTE(review): these headers are supplied by the client or by proxies, so the
     * result can be forged unless a trusted proxy overwrites them.
     *
     * @param request current request
     * @return the first address of the first usable header, or {@code request.getRemoteAddr()}
     */
    public String getIp(HttpServletRequest request) {
        for (String headerName : IP_HEADERS) {
            String value = request.getHeader(headerName);
            if (value == null || value.length() == 0 || "unknown".equalsIgnoreCase(value)) {
                continue;
            }
            // A forwarding header may hold "client, proxy1, proxy2"; keep the first entry.
            int comma = value.indexOf(',');
            String first = (comma < 0 ? value : value.substring(0, comma)).trim();
            if (!first.isEmpty()) {
                return first;
            }
        }
        return request.getRemoteAddr();
    }

    /**
     * Dispatches to the handler of the annotation's strategy.
     *
     * @param prevent        annotation found on the method
     * @param requestStr     client IP followed by the serialized arguments
     * @param methodFullName declaring class name followed by the method name
     */
    private void entrance(Prevent prevent, String requestStr, String methodFullName) throws Exception {
        PreventStrategyEnum strategy = prevent.strategy();
        switch (strategy) {
            case DEFAULT:
                defaultHandle(requestStr, prevent, methodFullName);
                break;
            default:
                throw new BusinessException("无效的策略");
        }
    }

    /**
     * Default strategy: accept the first call and reject identical calls until the
     * Redis entry expires.
     *
     * <p>NOTE(review): the lookup and the write are two separate Redis calls, so two
     * concurrent identical requests may both be accepted.
     *
     * @param requestStr     client IP followed by the serialized arguments
     * @param prevent        annotation supplying window length and message
     * @param methodFullName declaring class name followed by the method name
     */
    private void defaultHandle(String requestStr, Prevent prevent, String methodFullName) throws Exception {
        String cacheKey = methodFullName + toBase64String(requestStr);
        long expire = Long.parseLong(prevent.value());
        String resp = redisUtil.get(cacheKey);
        if (resp == null || resp.isEmpty()) {
            redisUtil.set(cacheKey, requestStr, expire);
            return;
        }
        String message = !prevent.message().isEmpty() ? prevent.message() :
                expire + "秒内不允许重复请求";
        throw new BusinessException(message);
    }

    /**
     * Encodes a string as Base64 over its UTF-8 bytes.
     *
     * @param obj text to encode
     * @return the Base64 text, or {@code null} if {@code obj} is null or empty
     */
    private String toBase64String(String obj) throws Exception {
        if (obj == null || obj.isEmpty()) {
            return null;
        }
        Base64.Encoder encoder = Base64.getEncoder();
        byte[] bytes = obj.getBytes("UTF-8");
        return encoder.encodeToString(bytes);
    }
}
controller接口使用注解
package com.ung.controllertest.controller;
import com.ung.controllertest.annotation.Prevent;
import com.ung.controllertest.pojo.dto.UserDto;
import com.ung.controllertest.service.UserService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import javax.servlet.http.HttpServletRequest;
import java.util.List;
/**
 * REST endpoints for users. The list and add endpoints are guarded by {@code @Prevent}.
 */
@RestController
public class UserController {

    /** Service every endpoint delegates to. */
    private UserService userService;

    /**
     * Creates the controller with its service dependency.
     */
    @Autowired
    public UserController(UserService userService) {
        this.userService = userService;
    }

    /**
     * Returns the user with the given id.
     */
    @GetMapping("/user/getById/{id}")
    public UserDto getById(@PathVariable("id") Long id) {
        UserDto found = userService.getById(id);
        return found;
    }

    /**
     * Returns all users. The request parameter is not used by this method itself.
     */
    @Prevent(value = "10", message = "10秒内不能重复请求")
    @GetMapping("/user/getList")
    public List<UserDto> getList(HttpServletRequest request) {
        List<UserDto> users = userService.getList();
        return users;
    }

    /**
     * Adds a user from the request body.
     */
    @Prevent(value = "10", message = "10秒内不能重复请求")
    @PostMapping("/user/addUser")
    public boolean addUser(HttpServletRequest request, @RequestBody UserDto userDto) {
        boolean added = userService.addUser(userDto);
        return added;
    }

    /**
     * Returns the name of the user with the given id.
     */
    @GetMapping("/user/getUserName/{id}")
    public String getUserName(@PathVariable("id") Long id) {
        String userName = userService.getUserName(id);
        return userName;
    }
}
最后请求接口测试
http://localhost:8001/user/addUser
{
"username": "username",
"password": "123456",
"email": "123@qq.com",
"phone": "13900445200"
}
连续请求会拒绝
对没有请求参数的get请求也是一样的有效
http://localhost:8001/user/getList