上篇介绍了使用SpringCloudGateway如何在内存中进行限流操作。
但是,SpringCloudGateway默认是返回HttpStatus.TOO_MANY_REQUESTS 429状态,可是我们并不希望用户看到这个不友好的提示,而是希望用户看到我们的自定义界面,更好看,或者去一个游戏界面让用户玩玩游戏吧。

有一个解决方案,一般在api网关的前面还会部署一个nginx服务,用于网关的负载均衡,那么可以在nginx收到429响应时,转向特定页面进行展示。
但是如果我们希望根据不同的服务,转向不同的页面,这样就需要维护2个地方:网关和nginx,增加了管理成本。

所以,继续研究……
首先查了一下《官方文档》,没有找到可供修改的地方。

接着,启动项目,跟踪了一下代码,发现项目读取配置后,是把配置绑定到RequestRateLimiterGatewayFilterFactory$Config类:

public static class Config implements HasRouteId {
	private KeyResolver keyResolver;
	private RateLimiter rateLimiter;
	private HttpStatus statusCode = HttpStatus.TOO_MANY_REQUESTS;
	private Boolean denyEmptyKey;
	private String emptyKeyStatus;
	private String routeId;

看到了吧?这些属性就是我们在上篇的application.yml里的配置。
里面有一个statusCode,默认赋值为HttpStatus.TOO_MANY_REQUESTS了,
可不可以在配置里,对它进行配置,改成302呢?修改一下配置文件如下:

filters:
  - name: RequestRateLimiter
    args: # 对应 RequestRateLimiterGatewayFilterFactory$Config类的属性
      key-resolver: "#{@myKeyResolver}"
      status-code: FOUND 

保存并启动项目,频繁刷新,果然不出429,而是出空白页面了。
抓包看了一下,响应的HTTP Status已经变成了302,但是因为Header里缺少Location,所以没有产生跳转。

接着,修改前文里的MyRateLimiter类,出现限流时,增加Location的Header:

if (probe.isConsumed()) {
    // 拿到令牌,允许进入
    return Mono.just(new Response(true, headers));
} else {
    // 没令牌了,默认返回429,不允许进入
    headers.put("Location", "https://www.baidu.com/");
    return Mono.just(new Response(false, headers));
}

启动项目,多刷新几次,页面果然跳到百度去了,至此302改造完成。


实际的项目中,大多使用前后端分离的开发模式,而限流一般也是针对后端API接口,所以如果限流返回302,一般是没有意义的,而是希望返回json格式的错误信息,那要怎么办呢?

跟踪了一下代码,发现限流最终是在RequestRateLimiterGatewayFilterFactory.apply(Config)方法里进行状态值的设置和返回。
RequestRateLimiterGatewayFilterFactory类实例又是怎么得到的呢?
是根据我们在application.yml里配置的filters.name: RequestRateLimiter,在RouteDefinitionRouteLocator类里:

List<GatewayFilter> loadGatewayFilters(String id, List<FilterDefinition> filterDefinitions) {
	ArrayList<GatewayFilter> ordered = new ArrayList<>(filterDefinitions.size());
	for (int i = 0; i < filterDefinitions.size(); i++) {
		FilterDefinition definition = filterDefinitions.get(i);
		GatewayFilterFactory factory = this.gatewayFilterFactories.get(definition.getName());

上面代码里definition.getName就是配置的name,gatewayFilterFactories是一个Map,Key是name,value是GatewayFilter类型的Bean数组,默认情况下,会有一个如下的映射:
RequestRateLimiter -> RequestRateLimiterGatewayFilterFactory
这个映射关系怎么来的??
再看代码,是在RouteDefinitionRouteLocator类的构造函数里赋值的:
gatewayFilterFactories.forEach(factory -> this.gatewayFilterFactories.put(factory.name(), factory));
这个factory.name()的实现如下:

default String name() {
	return NameUtils.normalizeFilterFactoryName(getClass());
}

public final class NameUtils {
	public static String normalizeFilterFactoryName(Class<? extends GatewayFilterFactory> clazz) {
		return removeGarbage(clazz.getSimpleName().replace(GatewayFilterFactory.class.getSimpleName(), ""));
	}

上面代码里的GatewayFilterFactory.class.getSimpleName(),结果是 GatewayFilterFactory
所以,在生成Bean: RequestRateLimiterGatewayFilterFactory 后,把它替换为RequestRateLimiter作为key,放到Map: gatewayFilterFactories 里。

综上,所以,如果我们直接定义一个自己的Bean,继承RequestRateLimiterGatewayFilterFactory,是不会生效的,一定要在配置里使用自定义的Factory的名字,比如我们新建一个类叫: AbcRequestRateLimiterGatewayFilterFactory
那么在配置里,就要使用如下name(否则你的类是没用处的):

filters:
  - name: AbcRequestRateLimiter
    args:
      key-resolver: "#{@myKeyResolver}"
      #status-code: FOUND 

好了,配置好了,那要实现我们自己的这个AbcRequestRateLimiterGatewayFilterFactory了,参考代码:

import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.RequestRateLimiterGatewayFilterFactory;
import org.springframework.cloud.gateway.filter.ratelimit.KeyResolver;
import org.springframework.cloud.gateway.filter.ratelimit.RateLimiter;
import org.springframework.cloud.gateway.route.Route;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.server.reactive.ServerHttpResponse;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;
import java.util.Map;

@Slf4j
public class AbcRequestRateLimiterGatewayFilterFactory extends RequestRateLimiterGatewayFilterFactory {

    private final RateLimiter defaultRateLimiter;

    private final KeyResolver defaultKeyResolver;

    public AbcRequestRateLimiterGatewayFilterFactory(RateLimiter defaultRateLimiter, KeyResolver defaultKeyResolver) {
        super(defaultRateLimiter, defaultKeyResolver);
        this.defaultRateLimiter = defaultRateLimiter;
        this.defaultKeyResolver = defaultKeyResolver;
    }

    @Override
    public GatewayFilter apply(Config config) {
        KeyResolver resolver = getOrDefault(config.getKeyResolver(), defaultKeyResolver);
        RateLimiter<Object> limiter = getOrDefault(config.getRateLimiter(), defaultRateLimiter);
        return (exchange, chain) -> resolver.resolve(exchange).flatMap(key -> {
//            if (EMPTY_KEY.equals(key)) {
//                if (denyEmpty) {
//                    setResponseStatus(exchange, emptyKeyStatus);
//                    return exchange.getResponse().setComplete();
//                }
//                return chain.filter(exchange);
//            }
            String routeId = config.getRouteId();
            if (routeId == null) {
                Route route = exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
                routeId = route.getId();
            }

            String finalRouteId = routeId;
            return limiter.isAllowed(routeId, key).flatMap(response -> {

                for (Map.Entry<String, String> header : response.getHeaders().entrySet()) {
                    exchange.getResponse().getHeaders().add(header.getKey(), header.getValue());
                }

                if (response.isAllowed()) {
                    return chain.filter(exchange);
                }

                log.warn("已限流: {}", finalRouteId);
                ServerHttpResponse httpResponse = exchange.getResponse();
                httpResponse.setStatusCode(config.getStatusCode());
                if (!httpResponse.getHeaders().containsKey("Content-Type")) {
                    httpResponse.getHeaders().add("Content-Type", "application/json");
                }
                DataBuffer buffer = httpResponse.bufferFactory().wrap("{'msg':'访问已受限制,请稍候重试'}".getBytes(StandardCharsets.UTF_8));
                return httpResponse.writeWith(Mono.just(buffer));

                // return exchange.getResponse().setComplete();
            });
        });
    }

    private <T> T getOrDefault(T configValue, T defaultValue) {
        return (configValue != null) ? configValue : defaultValue;
    }
}
Logo

权威|前沿|技术|干货|国内首个API全生命周期开发者社区

更多推荐