原创

限流算法之令牌桶

温馨提示:
本文最后更新于 2023年10月17日,已超过 463 天没有更新。若文章内的图片失效(无法正常加载),请留言反馈或直接联系我

概要

令牌桶算法是一种非常经典的流量控制算法,它以其简单而有效的特性被广泛应用于网络流量控制、速率限制等领域。

常见的限流算法有:计数器、时间窗口、滑动时间窗口、漏桶、令牌桶等等。

本文主要介绍令牌桶(Token Bucket)算法的原理及网关(Spring Cloud Gateway)限流场景的实际应用实现。首先需要对网关有一定了解。

原理

令牌桶算法最早由RFC 2698提出,它通过在系统中引入一个“令牌桶”的概念来实现流量控制。令牌桶是一个容器,它以固定的速率产生令牌,并将令牌放入桶中,当桶内令牌数达到桶的最大容量后,新生成的令牌会被丢弃。

当有请求到来时,先从桶内获取令牌,令牌足够则对请求进行放行,若无令牌则拒绝或延迟本次请求。

原理

实现

本文采用lua脚本实现网关令牌桶限流,并通过自定义配置实现令牌预支功能。

用redis对所有服务所有节点进行统一流量处理,基于redis高性能内存数据库特性,具有快速读写和扩展性的优势。而redis单线程执行lua脚本的特性,可有效避免并发冲突问题,确保流量处理的正确性和可靠性。

lua脚本

定义key为缓存key,可根据业务情况自定义限流力度,如:服务力度、接口力度、ip+接口力度等等。

并传入桶容量、固定生成令牌量、最大可预支令牌(以时间计算)、单次请求消耗令牌量(可根据不同维度控制令牌消耗对不同业务进行优先级和流量控制)等进行计算当前请求是否合法。

local key = 'token_bucket_key_' .. KEYS[1]

local getTokens = tonumber(ARGV[1])

local timeout = tonumber(ARGV[2] or -1)
--当前时间
local nowTime = tonumber(ARGV[3])
--初始最大值
local initMaxTokens = tonumber(ARGV[4])
--初始每秒令牌数
local initGenerateTokens = tonumber(ARGV[5])

local hasTokens = tonumber(redis.call('hget', key, 'hasTokens'))

local maxTokens = tonumber(redis.call('hget', key, 'maxTokens'))

local generateTokens = tonumber(redis.call('hget', key, 'generateTokens'))

local preTime = tonumber(redis.call('hget', key, 'preTime'))

local result = redis.call('exists', key)

if result == 0 then
    redis.call('hmset', key, 'hasTokens', initMaxTokens, 'maxTokens', initMaxTokens, 'generateTokens', initGenerateTokens, 'preTime', nowTime)
    redis.call('PEXPIRE', key, 2000)
    generateTokens = initGenerateTokens
    maxTokens = initMaxTokens
    hasTokens = initMaxTokens
    preTime = nowTime
end

-- 生成单令牌需要的时间
local generateTokenTime = 1000 / generateTokens

if timeout ~= -1 then
    if timeout < preTime - nowTime then
        return -1
    end
end

if nowTime > preTime then
    -- 获取上一次取令牌时间到当前时间的时间差
    local hasTime = nowTime - preTime
    -- 时间间隔生成的令牌数量
    local createTokens = hasTime / generateTokenTime
    -- 原有 + 新生 与 桶容量 比较取最小
    hasTokens = math.min(hasTokens + createTokens, maxTokens)
    preTime = nowTime
end

-- 获取可取令牌量  现有和支取量取最小
local canGetTokens = math.min(hasTokens, getTokens)

-- 预支量
local advanceTokens = getTokens - canGetTokens

-- 预支量消耗的时间
local advanceTime = advanceTokens * generateTokenTime

if timeout ~= -1 then
    -- 如果预支消耗时间大于可等待时间则说明预期等待时间内无法等到令牌
    if timeout < preTime + advanceTime - nowTime then
        return -1
    end
end

hasTokens = hasTokens - canGetTokens
advanceTime = math.floor(advanceTime)
redis.call('hmset', key, 'hasTokens', hasTokens, 'preTime', preTime + advanceTime)
redis.call('PEXPIRE', key, 2000 + preTime + advanceTime - nowTime)

return preTime - nowTime

调用执行lua脚本

执行lua脚本并对执行结果进行判断当前请求的合法性;

@Component
public class TokenBucket implements Serializable {
    private static final long serialVersionUID = 8249211818975021021L;

    @Autowired
    private StringRedisTemplate redisTemplate;

    public boolean getTokens(String key, String maxTokens, String generateTokens, String tokens, Integer timeout, TimeUnit unit) {
        DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("script/getToken.lua")));
        // 指定返回类型
        redisScript.setResultType(Long.class);
        Long waitTime = redisTemplate.execute(redisScript, Collections.singletonList(key), tokens, String.valueOf(unit.toMillis(timeout)),
                String.valueOf(System.currentTimeMillis()), maxTokens, generateTokens);
        if (Objects.isNull(waitTime) || waitTime == -1) {
            return false;
        }
        return true;
    }
}

自定义限流过滤器

Spring Cloud Gateway中,我没可以继承AbstractGatewayFilterFactory来实现自定义限流策略。类名称需要满足XXXGatewayFilterFactory,且XXX与配置文件保持一致。如本文TokenLimitGatewayFilterFactory配置文件须配置为TokenLimit方可识别。

Config配置类用于配置限流策略相关参数;

重写apply方法实现自定义限流逻辑;如本文中使用了开关和白名单来启用或关闭限流;

@Component
public class TokenLimitGatewayFilterFactory extends AbstractGatewayFilterFactory<TokenLimitGatewayFilterFactory.Config> {
    public static final String PREFIX_KEY = "key";
    public static final String GENERATE_TOKENS = "generateTokens";
    public static final String MAX_TOKENS = "maxTokens";
    public static final String TIMEOUT = "timeout";
    @Autowired
    private TokenBucket tokenBucket;
    @Value("#{'${ip.white-list:,}'.split(',')}")
    private List<String> ips;
    @Value("${filter.switch:true}")
    private Boolean filterSwitch;

    @Override
    public GatewayFilter apply(Config config) {
        return (exchange, chain) -> {
            ServerHttpRequest request = exchange.getRequest();
            String ip = request.getRemoteAddress().getAddress().getHostAddress();
            String url = request.getPath().toString();
            String key = config.getKey() + ip + "_" + url;
            if (filterSwitch && !ips.contains(ip)){
                Integer generateTokens = config.getGenerateTokens();
                Integer maxTokens = config.getMaxTokens();
                Integer timeout = config.getTimeout();
                if (Objects.isNull(timeout)){
                    timeout = 0;
                }
                boolean hasToken = tokenBucket.getTokens(key, String.valueOf(maxTokens), String.valueOf(generateTokens), "1", timeout, TimeUnit.MILLISECONDS);
                if (hasToken){
                    return chain.filter(exchange);
                }

                ServerHttpResponse response = exchange.getResponse();
                response.setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
                response.getHeaders().put("Content-Type", Collections.singletonList("application/json;charset=utf-8"));
                response.getHeaders().put("Access-Control-Allow-Origin", Collections.singletonList("*"));
                DataBuffer wrap = response.bufferFactory().wrap("".getBytes(StandardCharsets.UTF_8));
                return response.writeWith(Mono.just(wrap));
            }
            return chain.filter(exchange);
        };
    }

    public TokenLimitGatewayFilterFactory() {
        super(Config.class);
    }

    @Override
    public List<String> shortcutFieldOrder() {
        return Arrays.asList(PREFIX_KEY, GENERATE_TOKENS, MAX_TOKENS, TIMEOUT);
    }

    public static class Config {
        @NotEmpty(message = "key不能为空")
        private String key;
        @Min(value = 1, message = "桶每秒补充数量最小为1")
        @NotNull(message = "桶每秒补充数量不能为空")
        private Integer generateTokens;
        @Min(value = 1, message = "桶最大数量最小为1")
        @NotNull(message = "桶最大数量不能为空")
        private Integer maxTokens;
        @Min(value = 0, message = "超时时间最小为0")
        private Integer timeout;

        public String getKey() {
            return key;
        }

        public void setKey(String key) {
            this.key = key;
        }

        public Integer getGenerateTokens() {
            return generateTokens;
        }

        public void setGenerateTokens(Integer generateTokens) {
            this.generateTokens = generateTokens;
        }

        public Integer getMaxTokens() {
            return maxTokens;
        }

        public void setMaxTokens(Integer maxTokens) {
            this.maxTokens = maxTokens;
        }

        public Integer getTimeout() {
            return timeout;
        }

        public void setTimeout(Integer timeout) {
            this.timeout = timeout;
        }
    }
}

配置文件

配置文件中配置自定义限流策略。

spring:
  cloud:
    gateway:
      routes:
        - id: xxx-service
          uri: lb://xxx
          order: 0
          predicates:
            - Path=/gateway/xxx/**
          filters:
            - StripPrefix=1
            - TokenLimit=xxx_, 10, 15, 500

# ip白名单
ip:
  white-list: 127.0.0.1

# 开关
filter:
  switch: true
正文到此结束
本文目录