Redis LUA RateLimiter

参考文章:基于Redis的限流系统的设计
本文作为总结和回顾

限流脚本

首先需要引入的是redis的LUA脚本.


--- 判断source_str 中是否contains pattern_str
--- @param source_str
--- @param patter_str
local function contains(source_str, sub_str)
    local start_pos, end_pos = string.find(source_str, sub_str);
    if start_pos == nil then
        return false;
    end
    local source_str_len = string.len(source_str);

    if source_str_len == end_pos then
        return true
    elseif string.sub(source_str, end_pos + 1, end_pos + 1) == "," then
        return true
    end
    return false;
end


--- 获取令牌
--- 返回码
--- 0 没有令牌桶配置
--- -1 表示取令牌失败,也就是桶里没有令牌
--- 1 表示取令牌成功
--- @param key 令牌的唯一标识
--- @param permits  请求令牌数量
--- @param curr_mill_second 当前毫秒数
--- @param context 使用令牌的应用标识
local function acquire(key, permits, curr_mill_second, context)
    local rate_limit_info = redis.pcall("HMGET", key, "last_mill_second", "curr_permits", "max_permits", "rate", "apps")
    local last_mill_second = rate_limit_info[1]
    local curr_permits = tonumber(rate_limit_info[2])
    local max_permits = tonumber(rate_limit_info[3])
    local rate = rate_limit_info[4]
    local apps = rate_limit_info[5]

    --- 标识没有配置令牌桶
    if type(apps) == 'boolean' or apps == nil or not contains(apps, context) then
        return 0
    end


    local local_curr_permits = max_permits;

    --- 第一次, 没有last_mill_second, 所以local_curr_permits = max_permits, 
    --- 首先设置 redis.pcall("HSET", key, "last_mill_second", curr_mill_second) 将上次更新时间修改为现在
    --- 然后直接进入到最下面 redis.pcall("HSET", key, "curr_permits", local_curr_permits - permits) 用最大速率-自己用的这一次
    
    --- 第二次, 有last_mill_second, 走进去if的逻辑
    --- reverse_permits -> 时间差转换成秒, 然后乘以每秒的速率rage -> 得到每秒需要添加多少个令牌
    
    --- 第二次

    --- 令牌桶刚刚创建,上一次获取令牌的毫秒数为空
    --- 根据和上一次向桶里添加令牌的时间和当前时间差,触发式往桶里添加令牌,并且更新上一次向桶里添加令牌的时间
    --- 如果向桶里添加的令牌数不足一个,则不更新上一次向桶里添加令牌的时间
    if (type(last_mill_second) ~= 'boolean'  and last_mill_second ~= nil) then
        --- 当前消耗时间内 -> 需要添加多少个令牌
        --- 假设curr_mill_second - last_mill_second = 100ms,rate=10, 100ms/1000=0.1s * 10 = 1个
        --- 假设curr_mill_second - last_mill_second = 1000000ms,rate=10, 1000000ms/1000=1000s * 10 = 1w个 -》 说明很久没有访问了,下面math.min就会丢弃掉多的
        --- 即刚过去的这一段时间, 需要往桶里面添加1个令牌  
        local reverse_permits = math.floor(((curr_mill_second - last_mill_second) / 1000) * rate)
        --- 需要+的 + 当前还剩余的令牌 = 期望当前的令牌数量
        local expect_curr_permits = reverse_permits + curr_permits;
        --- 将期望的和最大速率比对, 取小的, 防止超载
        --- 多于最大速率后的漏(丢弃策略)
        local_curr_permits = math.min(expect_curr_permits, max_permits);

        --- 大于0表示不是第一次获取令牌,也没有向桶里添加令牌
        --- 如果当前消耗时间内需要添加令牌, 设置最新的添加时间为当前时间 
        if (reverse_permits > 0) then
            redis.pcall("HSET", key, "last_mill_second", curr_mill_second)
        end
    else
        redis.pcall("HSET", key, "last_mill_second", curr_mill_second)
    end


    local result = -1
    if (local_curr_permits - permits >= 0) then
        result = 1
        redis.pcall("HSET", key, "curr_permits", local_curr_permits - permits)
    else
        redis.pcall("HSET", key, "curr_permits", local_curr_permits)
    end

    return result
end


--- 初始化令牌桶配置
--- @param key 令牌的唯一标识
--- @param max_permits 桶大小
--- @param rate  向桶里添加令牌的速率
--- @param apps  可以使用令牌桶的应用列表,应用之前用逗号分隔
local function init(key, max_permits, rate, apps)
    local rate_limit_info = redis.pcall("HMGET", key, "last_mill_second", "curr_permits", "max_permits", "rate", "apps")
    local org_max_permits = tonumber(rate_limit_info[3])
    local org_rate = rate_limit_info[4]
    local org_apps = rate_limit_info[5]

    if (org_max_permits == nil) or (apps ~= org_apps or rate ~= org_rate or max_permits ~= org_max_permits) then
        redis.pcall("HMSET", key, "max_permits", max_permits, "rate", rate, "curr_permits", max_permits, "apps", apps)
    end
    return 1;
end


--- 删除令牌桶
local function delete(key)
    redis.pcall("DEL", key)
    return 1;
end


local key = KEYS[1]
local method = ARGV[1]

if method == 'acquire' then
    return acquire(key, ARGV[2], ARGV[3], ARGV[4])
elseif method == 'init' then
    return init(key, ARGV[2], ARGV[3], ARGV[4])
elseif method == 'delete' then
    return delete(key)
else
    --ignore
end


概览

  1. LUA的下标是从1开始的
  2. LUA的KEYS是调用方逗号分割前一部分,ARGV是后一部分。

整个脚本包含了4个方法,入口处使用ARGV[1]判断,决定调用哪一个方法

其核心使用了一个 HASH 来存放策略信息。其中包括:

  • max_permits -> 每秒最大速率
  • rate -> 每秒放入速率
  • curr_permits -> 当前速率
  • apps -> 应用列表
  • last_mill_second -> 上次添加令牌时间

delete

delete方法最为简单,就是调用DEL删除key即可,相当于删除了访问控制的策略。

init

初始化方法即将数据库或其他三方的元数据,加载到redis中,以redis的数据格式 Hash 存储。
此方法是幂等的,在加入之前会进行判断,如果没有才会添加。

acquire

此方法最为核心,是从桶中获取令牌的动作,并且模拟RateLimiter实现了触发式添加,从而提升QPS。

方法前面内容均为赋值操作和基本判空操作,没什么好说的。
local local_curr_permits = max_permits;便开始了触发式添加动作。
if (local_curr_permits - permits >= 0) then开始消耗动作。

--- 获取令牌
--- 返回码
--- 0 没有令牌桶配置
--- -1 表示取令牌失败,也就是桶里没有令牌
--- 1 表示取令牌成功
--- @param key 令牌的唯一标识
--- @param permits  请求令牌数量
--- @param curr_mill_second 当前毫秒数
--- @param context 使用令牌的应用标识
local function acquire(key, permits, curr_mill_second, context)
    local rate_limit_info = redis.pcall("HMGET", key, "last_mill_second", "curr_permits", "max_permits", "rate", "apps")
    local last_mill_second = rate_limit_info[1]
    local curr_permits = tonumber(rate_limit_info[2])
    local max_permits = tonumber(rate_limit_info[3])
    local rate = rate_limit_info[4]
    local apps = rate_limit_info[5]

    if type(apps) == 'boolean' or apps == nil or not contains(apps, context) then
        return 0
    end

    --- 将当前令牌调整为配置的最大速率
    local local_curr_permits = max_permits;

    --- 此处为触发式添加令牌动作
    --- 判断是否有上次添加令牌时间, 此处有三种情况:
    --- 1. 第一次进来, last_mill_second = nil, 直接走进else, 然后消耗令牌
    --- 2. 第二次进来, 分为两种情况
    --- 2.1 在第一次后的1秒内访问, 即下方情况1
    --- 2.2 在第一次后的1秒后访问, 即下方情况2
    if (type(last_mill_second) ~= 'boolean'  and last_mill_second ~= nil) then
        --- 向下取整(距离上次添加过去多少秒 * rate) = 需要向令牌桶添加的令牌数量
        --- 假设rate=10, max_permits=10
        --- 情况1: 距离上次添加过去多少秒 < 1秒 假设=0.2秒 * 10 = 2个令牌
        --- 情况2: 距离上次添加过去多少秒 > 1秒 假设=10秒 * 10 = 100个令牌
        local reverse_permits = math.floor(((curr_mill_second - last_mill_second) / 1000) * rate)
        --- 本次需要添加的令牌 + 剩余的令牌 = 期望令牌数量
        --- 情况1: 可能出现两种情况
        --- 情况1.1 reverse_permits = 2, curr_permits >= 9 , 即桶中还只消耗1个或没有消耗, 但是本次需要添加2个, 多了, 下方min函数会将其漏掉
        --- 情况1.2 reverse_permits = 2, curr_permits < 9 , 即桶中消耗了超过2个, 本次添加2个, 符合
        --- 情况2: 绝对多了, 下方min函数会将其漏掉
        local expect_curr_permits = reverse_permits + curr_permits;
        --- 此处解决上面两种情况, 防止过载(即将多的令牌丢弃的动作)
        --- 情况1: 期望令牌 < max_permits(10), 设置期望令牌到当前令牌
        --- 情况2: 期望令牌 > max_permits(10), 设置最大令牌到当前令牌 -> 即很久没有访问后的一次访问动作
        local_curr_permits = math.min(expect_curr_permits, max_permits);

        --- 如果需要向令牌桶添加的令牌数量 > 0, 就更新上次添加令牌时间
        if (reverse_permits > 0) then
            redis.pcall("HSET", key, "last_mill_second", curr_mill_second)
        end
    else
        redis.pcall("HSET", key, "last_mill_second", curr_mill_second)
    end


    local result = -1
    --- 当前令牌 够本次 申请的令牌数量
    if (local_curr_permits - permits >= 0) then
        result = 1
        --- 消耗令牌
        redis.pcall("HSET", key, "curr_permits", local_curr_permits - permits)
    else
        --- 上面的if不满足, 即令牌没有了
        redis.pcall("HSET", key, "curr_permits", local_curr_permits)
    end

    return result
end

Java客户端编写

载入脚本

整个脚本较大,可以事先载入到redis服务器,采用SHA1串访问。
多个实例创建多份rateLimiterLua也无妨,因为是同一个脚本文件内容,创建出来的sha1串也是一致的。

@Bean
public DefaultRedisScript<Long> rateLimiterLua() {
    DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
    redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/rate_limiter.lua")));
    redisScript.setResultType(Long.class);
    return redisScript;
}

编写Java访问客户端

  1. init() - 初始化访问策略
  2. acquire() - 获取令牌
  3. delete() - 删除访问策略
public class RateLimiterClient {

    @Autowired
    @Qualifier("longRedisTemplate")
    private RedisTemplate<String, Long> longRedisTemplate;
    @Autowired
    private RedisTemplate<String, String> redisTemplate;

    @Resource
    private RedisScript<Long> rateLimiterLua;

    /**
     * 获取访问令牌
     * @param context 应用名称
     * @param key 限制速率key
     * @param permits 获取令牌数量
     * @return token
     */
    public Token acquireToken(String context, String key, Integer permits) {

        Token token;
        try {
            // redis当前时间
            Long currMillSecond = longRedisTemplate.execute(RedisServerCommands::time);
            // 获取令牌
            Long acquire = longRedisTemplate.execute(rateLimiterLua, ImmutableList.of(getKey(key)), RATE_LIMITER_ACQUIRE_METHOD, permits.toString(), currMillSecond, context);

            if (acquire == null) {
                log.error("no rate limit config for context = {}", context);
                return Token.NO_CONFIG;
            }

            if (acquire == 1) {
                token = Token.PASS;
            } else if (acquire == -1) {
                token = Token.FUSING;
            } else {
                log.error("no rate limit config for context = {}", context);
                token = Token.NO_CONFIG;
            }
        } catch (Exception e) {
            log.error("get rage limit token for redis error, key = " + key, e);
            token = Token.ACCESS_REDIS_FAIL;
        }
        return token;
    }

    public void deleteRateLimiter(String key) {
        redisTemplate.execute(rateLimiterLua, ImmutableList.of(getKey(key)), RATE_LIMITER_DELETE_METHOD);
    }

    public void initRateLimiter(String code, Integer maxPermits, Integer rate, String apps) {
        redisTemplate.execute(rateLimiterLua,
                ImmutableList.of(getKey(code)),
                RATE_LIMITER_INIT_METHOD,
                maxPermits.toString(),
                rate.toString(),
                apps);
    }
    
    private String getKey(String key) {
        return RateLimiterConstants.RATE_LIMITER_KEY_PREFIX + key;
    }

}