diff --git a/README.md b/README.md index 954f254..5181f23 100644 --- a/README.md +++ b/README.md @@ -247,6 +247,30 @@ ratelimit = Ratelimit( ) ``` +# Custom Rates + +When rate limiting, you may want different requests to consume different amounts of tokens. +This could be useful when processing batches of requests where you want to rate limit based +on items in the batch or when you want to rate limit based on the number of tokens. + +To achieve this, you can simply pass `rate` parameter when calling the limit method: + +```python + +from upstash_ratelimit import Ratelimit, FixedWindow +from upstash_redis import Redis + +ratelimit = Ratelimit( + redis=Redis.from_env(), + limiter=FixedWindow(max_requests=10, window=10), +) + +# pass rate as 5 to subtract 5 from the number of +# allowed requests in the window: +identifier = "api" +response = ratelimit.limit(identifier, rate=5) +``` + # Contributing ## Preparing the environment diff --git a/tests/test_fixed_window.py b/tests/test_fixed_window.py index 6c6c870..636e586 100644 --- a/tests/test_fixed_window.py +++ b/tests/test_fixed_window.py @@ -84,3 +84,20 @@ def test_get_reset(redis: Redis) -> None: with patch("time.time", return_value=1688910786.167): assert ratelimit.get_reset(random_id()) == approx(1688910790.0) + + +def test_custom_rate(redis: Redis) -> None: + ratelimit = Ratelimit( + redis=redis, + limiter=FixedWindow(max_requests=10, window=1, unit="d"), + ) + rate = 2 + + id = random_id() + + ratelimit.limit(id) + ratelimit.limit(id, rate) + assert ratelimit.get_remaining(id) == 7 + + ratelimit.limit(id, rate) + assert ratelimit.get_remaining(id) == 5 diff --git a/tests/test_sliding_window.py b/tests/test_sliding_window.py index 40b6127..4906666 100644 --- a/tests/test_sliding_window.py +++ b/tests/test_sliding_window.py @@ -130,3 +130,20 @@ def test_get_reset(redis: Redis) -> None: with patch("time.time", return_value=1688910786.167): assert ratelimit.get_reset(random_id()) == approx(1688910790.0) + + +def test_custom_rate(redis: Redis) -> None: + ratelimit = Ratelimit( + redis=redis, + limiter=SlidingWindow(max_requests=10, window=5), + ) + rate = 2 + + id = random_id() + + ratelimit.limit(id) + ratelimit.limit(id, rate) + assert ratelimit.get_remaining(id) == 7 + + ratelimit.limit(id, rate) + assert ratelimit.get_remaining(id) == 5 diff --git a/tests/test_token_bucket.py b/tests/test_token_bucket.py index 739c546..56a80f4 100644 --- a/tests/test_token_bucket.py +++ b/tests/test_token_bucket.py @@ -146,3 +146,20 @@ def test_get_reset_with_refills_that_should_be_made(redis: Redis) -> None: time.sleep(3) assert ratelimit.get_reset(id) >= last_reset + 2 + + +def test_custom_rate(redis: Redis) -> None: + ratelimit = Ratelimit( + redis=redis, + limiter=TokenBucket(max_tokens=10, refill_rate=1, interval=1), + ) + rate = 2 + + id = random_id() + + ratelimit.limit(id) + ratelimit.limit(id, rate) + assert ratelimit.get_remaining(id) == 7 + + ratelimit.limit(id, rate) + assert ratelimit.get_remaining(id) == 5 diff --git a/upstash_ratelimit/asyncio/ratelimit.py b/upstash_ratelimit/asyncio/ratelimit.py index 4a89f01..c0f245e 100644 --- a/upstash_ratelimit/asyncio/ratelimit.py +++ b/upstash_ratelimit/asyncio/ratelimit.py @@ -32,7 +32,7 @@ def __init__( self._limiter = limiter self._prefix = prefix - async def limit(self, identifier: str) -> Response: + async def limit(self, identifier: str, rate: int = 1) -> Response: """ Determines if a request should pass or be rejected based on the identifier and previously chosen ratelimit. @@ -60,12 +60,14 @@ async def main() -> None: :param identifier: Identifier to ratelimit. Use a constant string to \ limit all requests, or user ids, API keys, or IP addresses for \ individual limits. + :param rate: Rate with which to subtract from the limit of the \ + identifier. """ key = f"{self._prefix}:{identifier}" - return await self._limiter.limit_async(self._redis, key) + return await self._limiter.limit_async(self._redis, key, rate) - async def block_until_ready(self, identifier: str, timeout: float) -> Response: + async def block_until_ready(self, identifier: str, timeout: float, rate: int = 1) -> Response: """ Blocks until the request may pass or timeout is reached. @@ -97,6 +99,8 @@ async def main() -> None: individual limits. :param timeout: Maximum time in seconds to wait until the request \ may pass. + :param rate: Rate with which to subtract from the limit of the \ + identifier. """ if timeout <= 0: @@ -106,7 +110,7 @@ async def main() -> None: deadline = now_s() + timeout while True: - response = await self.limit(identifier) + response = await self.limit(identifier, rate) if response.allowed: break diff --git a/upstash_ratelimit/limiter.py b/upstash_ratelimit/limiter.py index d2e7d41..f22832e 100644 --- a/upstash_ratelimit/limiter.py +++ b/upstash_ratelimit/limiter.py @@ -35,11 +35,11 @@ class Response: class Limiter(abc.ABC): @abc.abstractmethod - def limit(self, redis: Redis, identifier: str) -> Response: + def limit(self, redis: Redis, identifier: str, rate: int = 1) -> Response: pass @abc.abstractmethod - async def limit_async(self, redis: AsyncRedis, identifier: str) -> Response: + async def limit_async(self, redis: AsyncRedis, identifier: str, rate: int = 1) -> Response: pass @abc.abstractmethod @@ -108,16 +108,16 @@ async def _with_at_most_one_request_async( class AbstractLimiter(Limiter): @abc.abstractmethod - def _limit(self, identifier: str) -> Generator: + def _limit(self, identifier: str, rate: int = 1) -> Generator: pass - def limit(self, redis: Redis, identifier: str) -> Response: - response: Response = _with_at_most_one_request(redis, self._limit(identifier)) + def limit(self, redis: Redis, identifier: str, rate: int = 1) -> Response: + response: Response = _with_at_most_one_request(redis, self._limit(identifier, rate)) return response - async def limit_async(self, redis: AsyncRedis, identifier: str) -> Response: + async def limit_async(self, redis: AsyncRedis, identifier: str, rate: int = 1) -> Response: response: Response = await _with_at_most_one_request_async( - redis, self._limit(identifier) + redis, self._limit(identifier, rate) ) return response @@ -171,17 +171,18 @@ class FixedWindow(AbstractLimiter): """ SCRIPT = """ - local key = KEYS[1] -- identifier including prefixes - local window = ARGV[1] -- interval in milliseconds - - local num_requests = redis.call("INCR", key) - if num_requests == 1 then - -- The first time this key is set, the value will be 1. - -- So we only need the expire command once - redis.call("PEXPIRE", key, window) + local key = KEYS[1] + local window = ARGV[1] + local increment_by = ARGV[2] -- increment rate per request at a given value, default is 1 + + local r = redis.call("INCRBY", key, increment_by) + if r == tonumber(increment_by) then + -- The first time this key is set, the value will be equal to increment_by. + -- So we only need the expire command once + redis.call("PEXPIRE", key, window) end - - return num_requests + + return r """ def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None: @@ -197,13 +198,13 @@ def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None: self._max_requests = max_requests self._window = to_ms(window, unit) - def _limit(self, identifier: str) -> Generator: + def _limit(self, identifier: str, rate: int = 1) -> Generator: curr_window = now_ms() // self._window key = f"{identifier}:{curr_window}" num_requests = yield ( "eval", - (FixedWindow.SCRIPT, [key], [self._window]), + (FixedWindow.SCRIPT, [key], [self._window, rate]), ) yield Response( @@ -245,38 +246,36 @@ class SlidingWindow(AbstractLimiter): """ SCRIPT = """ - local key = KEYS[1] -- identifier including prefixes - local prev_key = KEYS[2] -- key of the previous bucket - local max_requests = tonumber(ARGV[1]) -- max requests per window - local now = tonumber(ARGV[2]) -- current timestamp in milliseconds - local window = tonumber(ARGV[3]) -- interval in milliseconds - - local num_requests = redis.call("GET", key) - if num_requests == false then - num_requests = 0 + local current_key = KEYS[1] -- identifier including prefixes + local previous_key = KEYS[2] -- key of the previous bucket + local tokens = tonumber(ARGV[1]) -- tokens per window + local now = ARGV[2] -- current timestamp in milliseconds + local window = ARGV[3] -- interval in milliseconds + local increment_by = ARGV[4] -- increment rate per request at a given value, default is 1 + + local requests_in_current_window = redis.call("GET", current_key) + if requests_in_current_window == false then + requests_in_current_window = 0 end - local prev_num_requests = redis.call("GET", prev_key) - if prev_num_requests == false then - prev_num_requests = 0 + local requests_in_previous_window = redis.call("GET", previous_key) + if requests_in_previous_window == false then + requests_in_previous_window = 0 end - - local prev_window_weight = 1 - ((now % window) / window) - -- requests to consider from prev window - prev_num_requests = math.floor(prev_num_requests * prev_window_weight) - - if num_requests + prev_num_requests >= max_requests then - return -1 + local percentage_in_current = ( now % window ) / window + -- weighted requests to consider from the previous window + requests_in_previous_window = math.floor(( 1 - percentage_in_current ) * requests_in_previous_window) + if requests_in_previous_window + requests_in_current_window >= tokens then + return -1 end - num_requests = redis.call("INCR", key) - if num_requests == 1 then - -- The first time this key is set, the value will be 1. - -- So we only need the expire command once - redis.call("PEXPIRE", key, window * 2 + 1000) -- Enough time to overlap with a new window + 1 second + local new_value = redis.call("INCRBY", current_key, increment_by) + if new_value == tonumber(increment_by) then + -- The first time this key is set, the value will be equal to increment_by. + -- So we only need the expire command once + redis.call("PEXPIRE", current_key, window * 2 + 1000) -- Enough time to overlap with a new window + 1 second end - - return max_requests - (num_requests + prev_num_requests) + return tokens - ( new_value + requests_in_previous_window ) """ def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None: @@ -292,7 +291,7 @@ def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None: self._max_requests = max_requests self._window = to_ms(window, unit) - def _limit(self, identifier: str) -> Generator: + def _limit(self, identifier: str, rate: int = 1) -> Generator: now = now_ms() curr_window = now // self._window @@ -306,7 +305,7 @@ def _limit(self, identifier: str) -> Generator: ( SlidingWindow.SCRIPT, [key, prev_key], - [self._max_requests, now, self._window], + [self._max_requests, now, self._window, rate], ), ) @@ -362,39 +361,40 @@ class TokenBucket(AbstractLimiter): """ SCRIPT = """ - local key = KEYS[1] -- identifier including prefixes - local max_tokens = tonumber(ARGV[1]) -- maximum number of tokens - local interval = tonumber(ARGV[2]) -- size of the interval in milliseconds - local refill_rate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval - local now = tonumber(ARGV[4]) -- current timestamp in milliseconds - + local key = KEYS[1] -- identifier including prefixes + local max_tokens = tonumber(ARGV[1]) -- maximum number of tokens + local interval = tonumber(ARGV[2]) -- size of the window in milliseconds + local refill_rate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval + local now = tonumber(ARGV[4]) -- current timestamp in milliseconds + local increment_by = tonumber(ARGV[5]) -- how many tokens to consume, default is 1 + local bucket = redis.call("HMGET", key, "refilled_at", "tokens") - + local refilled_at local tokens if bucket[1] == false then - refilled_at = now - tokens = max_tokens + refilled_at = now + tokens = max_tokens else - refilled_at = tonumber(bucket[1]) - tokens = tonumber(bucket[2]) + refilled_at = tonumber(bucket[1]) + tokens = tonumber(bucket[2]) end - + if now >= refilled_at + interval then - local num_refills = math.floor((now - refilled_at) / interval) - tokens = math.min(max_tokens, tokens + num_refills * refill_rate) + local num_refills = math.floor((now - refilled_at) / interval) + tokens = math.min(max_tokens, tokens + num_refills * refill_rate) - refilled_at = refilled_at + num_refills * interval + refilled_at = refilled_at + num_refills * interval end if tokens == 0 then - return {-1, refilled_at + interval} + return {-1, refilled_at + interval} end - local remaining = tokens - 1 + local remaining = tokens - increment_by local expire_at = math.ceil(((max_tokens - remaining) / refill_rate)) * interval - + redis.call("HSET", key, "refilled_at", refilled_at, "tokens", remaining) redis.call("PEXPIRE", key, expire_at) return {remaining, refilled_at + interval} @@ -419,13 +419,13 @@ def __init__( self._refill_rate = refill_rate self._interval = to_ms(interval, unit) - def _limit(self, identifier: str) -> Generator: + def _limit(self, identifier: str, rate: int = 1) -> Generator: remaining, refill_at = yield ( "eval", ( TokenBucket.SCRIPT, [identifier], - [self._max_tokens, self._interval, self._refill_rate, now_ms()], + [self._max_tokens, self._interval, self._refill_rate, now_ms(), rate], ), ) diff --git a/upstash_ratelimit/ratelimit.py b/upstash_ratelimit/ratelimit.py index b9fe3b7..bba157a 100644 --- a/upstash_ratelimit/ratelimit.py +++ b/upstash_ratelimit/ratelimit.py @@ -32,7 +32,7 @@ def __init__( self._limiter = limiter self._prefix = prefix - def limit(self, identifier: str) -> Response: + def limit(self, identifier: str, rate: int = 1) -> Response: """ Determines if a request should pass or be rejected based on the identifier and previously chosen ratelimit. @@ -59,12 +59,14 @@ def limit(self, identifier: str) -> Response: :param identifier: Identifier to ratelimit. Use a constant string to \ limit all requests, or user ids, API keys, or IP addresses for \ individual limits. + :param rate: Rate with which to subtract from the limit of the \ + identifier. """ key = f"{self._prefix}:{identifier}" - return self._limiter.limit(self._redis, key) + return self._limiter.limit(self._redis, key, rate) - def block_until_ready(self, identifier: str, timeout: float) -> Response: + def block_until_ready(self, identifier: str, timeout: float, rate: int = 1) -> Response: """ Blocks until the request may pass or timeout is reached. @@ -95,6 +97,8 @@ def block_until_ready(self, identifier: str, timeout: float) -> Response: individual limits. :param timeout: Maximum time in seconds to wait until the request \ may pass. + :param rate: Rate with which to subtract from the limit of the \ + identifier. """ if timeout <= 0: @@ -104,7 +108,7 @@ def block_until_ready(self, identifier: str, timeout: float) -> Response: deadline = now_s() + timeout while True: - response = self.limit(identifier) + response = self.limit(identifier, rate) if response.allowed: break