mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-01-19 01:21:01 +08:00
46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
import threading
|
|
import time
|
|
|
|
|
|
class TokenBucket:
|
|
def __init__(self, tpm, timeout=None):
|
|
self.capacity = int(tpm) # 令牌桶容量
|
|
self.tokens = 0 # 初始令牌数为0
|
|
self.rate = int(tpm) / 60 # 令牌每秒生成速率
|
|
self.timeout = timeout # 等待令牌超时时间
|
|
self.cond = threading.Condition() # 条件变量
|
|
self.is_running = True
|
|
# 开启令牌生成线程
|
|
threading.Thread(target=self._generate_tokens).start()
|
|
|
|
def _generate_tokens(self):
|
|
"""生成令牌"""
|
|
while self.is_running:
|
|
with self.cond:
|
|
if self.tokens < self.capacity:
|
|
self.tokens += 1
|
|
self.cond.notify() # 通知获取令牌的线程
|
|
time.sleep(1 / self.rate)
|
|
|
|
def get_token(self):
|
|
"""获取令牌"""
|
|
with self.cond:
|
|
while self.tokens <= 0:
|
|
flag = self.cond.wait(self.timeout)
|
|
if not flag: # 超时
|
|
return False
|
|
self.tokens -= 1
|
|
return True
|
|
|
|
def close(self):
|
|
self.is_running = False
|
|
|
|
|
|
if __name__ == "__main__":
|
|
token_bucket = TokenBucket(20, None) # 创建一个每分钟生产20个tokens的令牌桶
|
|
# token_bucket = TokenBucket(20, 0.1)
|
|
for i in range(3):
|
|
if token_bucket.get_token():
|
|
print(f"第{i+1}次请求成功")
|
|
token_bucket.close()
|