Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Custom Rates #13

Merged
merged 1 commit into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,30 @@ ratelimit = Ratelimit(
)
```

# Custom Rates

When rate limiting, you may want different requests to consume different amounts of tokens.
This could be useful when processing batches of requests where you want to rate limit based
on items in the batch or when you want to rate limit based on the number of tokens.

To achieve this, you can simply pass `rate` parameter when calling the limit method:

```python

from upstash_ratelimit import Ratelimit, FixedWindow
from upstash_redis import Redis

ratelimit = Ratelimit(
redis=Redis.from_env(),
limiter=FixedWindow(max_requests=10, window=10),
)

# pass rate as 5 to subtract 5 from the number of
# allowed requests in the window:
identifier = "api"
response = ratelimit.limit(identifier, rate=5)
```

# Contributing

## Preparing the environment
Expand Down
17 changes: 17 additions & 0 deletions tests/test_fixed_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,20 @@ def test_get_reset(redis: Redis) -> None:

with patch("time.time", return_value=1688910786.167):
assert ratelimit.get_reset(random_id()) == approx(1688910790.0)


def test_custom_rate(redis: Redis) -> None:
ratelimit = Ratelimit(
redis=redis,
limiter=FixedWindow(max_requests=10, window=1, unit="d"),
)
rate = 2

id = random_id()

ratelimit.limit(id)
ratelimit.limit(id, rate)
assert ratelimit.get_remaining(id) == 7

ratelimit.limit(id, rate)
assert ratelimit.get_remaining(id) == 5
17 changes: 17 additions & 0 deletions tests/test_sliding_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,20 @@ def test_get_reset(redis: Redis) -> None:

with patch("time.time", return_value=1688910786.167):
assert ratelimit.get_reset(random_id()) == approx(1688910790.0)


def test_custom_rate(redis: Redis) -> None:
ratelimit = Ratelimit(
redis=redis,
limiter=SlidingWindow(max_requests=10, window=5),
)
rate = 2

id = random_id()

ratelimit.limit(id)
ratelimit.limit(id, rate)
assert ratelimit.get_remaining(id) == 7

ratelimit.limit(id, rate)
assert ratelimit.get_remaining(id) == 5
17 changes: 17 additions & 0 deletions tests/test_token_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,20 @@ def test_get_reset_with_refills_that_should_be_made(redis: Redis) -> None:
time.sleep(3)

assert ratelimit.get_reset(id) >= last_reset + 2


def test_custom_rate(redis: Redis) -> None:
ratelimit = Ratelimit(
redis=redis,
limiter=TokenBucket(max_tokens=10, refill_rate=1, interval=1),
)
rate = 2

id = random_id()

ratelimit.limit(id)
ratelimit.limit(id, rate)
assert ratelimit.get_remaining(id) == 7

ratelimit.limit(id, rate)
assert ratelimit.get_remaining(id) == 5
12 changes: 8 additions & 4 deletions upstash_ratelimit/asyncio/ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
self._limiter = limiter
self._prefix = prefix

async def limit(self, identifier: str) -> Response:
async def limit(self, identifier: str, rate: int = 1) -> Response:
"""
Determines if a request should pass or be rejected based on the identifier
and previously chosen ratelimit.
Expand Down Expand Up @@ -60,12 +60,14 @@ async def main() -> None:
:param identifier: Identifier to ratelimit. Use a constant string to \
limit all requests, or user ids, API keys, or IP addresses for \
individual limits.
:param rate: Rate with which to subtract from the limit of the \
identifier.
"""

key = f"{self._prefix}:{identifier}"
return await self._limiter.limit_async(self._redis, key)
return await self._limiter.limit_async(self._redis, key, rate)

async def block_until_ready(self, identifier: str, timeout: float) -> Response:
async def block_until_ready(self, identifier: str, timeout: float, rate: int = 1) -> Response:
"""
Blocks until the request may pass or timeout is reached.

Expand Down Expand Up @@ -97,6 +99,8 @@ async def main() -> None:
individual limits.
:param timeout: Maximum time in seconds to wait until the request \
may pass.
:param rate: Rate with which to subtract from the limit of the \
identifier.
"""

if timeout <= 0:
Expand All @@ -106,7 +110,7 @@ async def main() -> None:
deadline = now_s() + timeout

while True:
response = await self.limit(identifier)
response = await self.limit(identifier, rate)
if response.allowed:
break

Expand Down
134 changes: 67 additions & 67 deletions upstash_ratelimit/limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ class Response:

class Limiter(abc.ABC):
@abc.abstractmethod
def limit(self, redis: Redis, identifier: str) -> Response:
def limit(self, redis: Redis, identifier: str, rate: int = 1) -> Response:
pass

@abc.abstractmethod
async def limit_async(self, redis: AsyncRedis, identifier: str) -> Response:
async def limit_async(self, redis: AsyncRedis, identifier: str, rate: int = 1) -> Response:
pass

@abc.abstractmethod
Expand Down Expand Up @@ -108,16 +108,16 @@ async def _with_at_most_one_request_async(

class AbstractLimiter(Limiter):
@abc.abstractmethod
def _limit(self, identifier: str) -> Generator:
def _limit(self, identifier: str, rate: int = 1) -> Generator:
pass

def limit(self, redis: Redis, identifier: str) -> Response:
response: Response = _with_at_most_one_request(redis, self._limit(identifier))
def limit(self, redis: Redis, identifier: str, rate: int = 1) -> Response:
response: Response = _with_at_most_one_request(redis, self._limit(identifier, rate))
return response

async def limit_async(self, redis: AsyncRedis, identifier: str) -> Response:
async def limit_async(self, redis: AsyncRedis, identifier: str, rate: int = 1) -> Response:
response: Response = await _with_at_most_one_request_async(
redis, self._limit(identifier)
redis, self._limit(identifier, rate)
)
return response

Expand Down Expand Up @@ -171,17 +171,18 @@ class FixedWindow(AbstractLimiter):
"""

SCRIPT = """
local key = KEYS[1] -- identifier including prefixes
local window = ARGV[1] -- interval in milliseconds

local num_requests = redis.call("INCR", key)
if num_requests == 1 then
-- The first time this key is set, the value will be 1.
-- So we only need the expire command once
redis.call("PEXPIRE", key, window)
local key = KEYS[1]
local window = ARGV[1]
local increment_by = ARGV[2] -- increment rate per request at a given value, default is 1

local r = redis.call("INCRBY", key, increment_by)
if r == tonumber(increment_by) then
-- The first time this key is set, the value will be equal to increment_by.
-- So we only need the expire command once
redis.call("PEXPIRE", key, window)
end
return num_requests

return r
"""

def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None:
Expand All @@ -197,13 +198,13 @@ def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None:
self._max_requests = max_requests
self._window = to_ms(window, unit)

def _limit(self, identifier: str) -> Generator:
def _limit(self, identifier: str, rate: int = 1) -> Generator:
curr_window = now_ms() // self._window
key = f"{identifier}:{curr_window}"

num_requests = yield (
"eval",
(FixedWindow.SCRIPT, [key], [self._window]),
(FixedWindow.SCRIPT, [key], [self._window, rate]),
)

yield Response(
Expand Down Expand Up @@ -245,38 +246,36 @@ class SlidingWindow(AbstractLimiter):
"""

SCRIPT = """
local key = KEYS[1] -- identifier including prefixes
local prev_key = KEYS[2] -- key of the previous bucket
local max_requests = tonumber(ARGV[1]) -- max requests per window
local now = tonumber(ARGV[2]) -- current timestamp in milliseconds
local window = tonumber(ARGV[3]) -- interval in milliseconds

local num_requests = redis.call("GET", key)
if num_requests == false then
num_requests = 0
local current_key = KEYS[1] -- identifier including prefixes
local previous_key = KEYS[2] -- key of the previous bucket
local tokens = tonumber(ARGV[1]) -- tokens per window
local now = ARGV[2] -- current timestamp in milliseconds
local window = ARGV[3] -- interval in milliseconds
local increment_by = ARGV[4] -- increment rate per request at a given value, default is 1

local requests_in_current_window = redis.call("GET", current_key)
if requests_in_current_window == false then
requests_in_current_window = 0
end

local prev_num_requests = redis.call("GET", prev_key)
if prev_num_requests == false then
prev_num_requests = 0
local requests_in_previous_window = redis.call("GET", previous_key)
if requests_in_previous_window == false then
requests_in_previous_window = 0
end

local prev_window_weight = 1 - ((now % window) / window)
-- requests to consider from prev window
prev_num_requests = math.floor(prev_num_requests * prev_window_weight)

if num_requests + prev_num_requests >= max_requests then
return -1
local percentage_in_current = ( now % window ) / window
-- weighted requests to consider from the previous window
requests_in_previous_window = math.floor(( 1 - percentage_in_current ) * requests_in_previous_window)
if requests_in_previous_window + requests_in_current_window >= tokens then
return -1
end

num_requests = redis.call("INCR", key)
if num_requests == 1 then
-- The first time this key is set, the value will be 1.
-- So we only need the expire command once
redis.call("PEXPIRE", key, window * 2 + 1000) -- Enough time to overlap with a new window + 1 second
local new_value = redis.call("INCRBY", current_key, increment_by)
if new_value == tonumber(increment_by) then
-- The first time this key is set, the value will be equal to increment_by.
-- So we only need the expire command once
redis.call("PEXPIRE", current_key, window * 2 + 1000) -- Enough time to overlap with a new window + 1 second
end

return max_requests - (num_requests + prev_num_requests)
return tokens - ( new_value + requests_in_previous_window )
"""

def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None:
Expand All @@ -292,7 +291,7 @@ def __init__(self, max_requests: int, window: int, unit: UnitT = "s") -> None:
self._max_requests = max_requests
self._window = to_ms(window, unit)

def _limit(self, identifier: str) -> Generator:
def _limit(self, identifier: str, rate: int = 1) -> Generator:
now = now_ms()

curr_window = now // self._window
Expand All @@ -306,7 +305,7 @@ def _limit(self, identifier: str) -> Generator:
(
SlidingWindow.SCRIPT,
[key, prev_key],
[self._max_requests, now, self._window],
[self._max_requests, now, self._window, rate],
),
)

Expand Down Expand Up @@ -362,39 +361,40 @@ class TokenBucket(AbstractLimiter):
"""

SCRIPT = """
local key = KEYS[1] -- identifier including prefixes
local max_tokens = tonumber(ARGV[1]) -- maximum number of tokens
local interval = tonumber(ARGV[2]) -- size of the interval in milliseconds
local refill_rate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval
local now = tonumber(ARGV[4]) -- current timestamp in milliseconds

local key = KEYS[1] -- identifier including prefixes
local max_tokens = tonumber(ARGV[1]) -- maximum number of tokens
local interval = tonumber(ARGV[2]) -- size of the window in milliseconds
local refill_rate = tonumber(ARGV[3]) -- how many tokens are refilled after each interval
local now = tonumber(ARGV[4]) -- current timestamp in milliseconds
local increment_by = tonumber(ARGV[5]) -- how many tokens to consume, default is 1

local bucket = redis.call("HMGET", key, "refilled_at", "tokens")

local refilled_at
local tokens

if bucket[1] == false then
refilled_at = now
tokens = max_tokens
refilled_at = now
tokens = max_tokens
else
refilled_at = tonumber(bucket[1])
tokens = tonumber(bucket[2])
refilled_at = tonumber(bucket[1])
tokens = tonumber(bucket[2])
end

if now >= refilled_at + interval then
local num_refills = math.floor((now - refilled_at) / interval)
tokens = math.min(max_tokens, tokens + num_refills * refill_rate)
local num_refills = math.floor((now - refilled_at) / interval)
tokens = math.min(max_tokens, tokens + num_refills * refill_rate)

refilled_at = refilled_at + num_refills * interval
refilled_at = refilled_at + num_refills * interval
end

if tokens == 0 then
return {-1, refilled_at + interval}
return {-1, refilled_at + interval}
end

local remaining = tokens - 1
local remaining = tokens - increment_by
local expire_at = math.ceil(((max_tokens - remaining) / refill_rate)) * interval

redis.call("HSET", key, "refilled_at", refilled_at, "tokens", remaining)
redis.call("PEXPIRE", key, expire_at)
return {remaining, refilled_at + interval}
Expand All @@ -419,13 +419,13 @@ def __init__(
self._refill_rate = refill_rate
self._interval = to_ms(interval, unit)

def _limit(self, identifier: str) -> Generator:
def _limit(self, identifier: str, rate: int = 1) -> Generator:
remaining, refill_at = yield (
"eval",
(
TokenBucket.SCRIPT,
[identifier],
[self._max_tokens, self._interval, self._refill_rate, now_ms()],
[self._max_tokens, self._interval, self._refill_rate, now_ms(), rate],
),
)

Expand Down
Loading
Loading