1import asyncio
2import contextlib
3import collections
4import time
5
6from types import TracebackType
7from typing import Dict, Optional, Type
8
# Python 3.7+ provides an abstract async context manager base class and the
# module-level current_task(); on older versions fall back to a plain
# ``object`` base and the Task.current_task classmethod.
try:
    _current_task = asyncio.current_task
    base = contextlib.AbstractAsyncContextManager
except AttributeError:
    _current_task = asyncio.Task.current_task  # type: ignore
    base = object  # type: ignore
15
class AsyncLeakyBucket(base):
    """A leaky bucket rate limiter.

    Allows up to max_rate / time_period acquisitions before blocking.

    time_period is measured in seconds; the default is 60.

    The bucket can be used directly (``await bucket.acquire(amount)``) or as
    an async context manager, where entering acquires one unit of capacity
    and exiting does nothing (capacity drains out over time instead).

    """

    def __init__(
        self,
        max_rate: float,
        time_period: float = 60,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        # Event loop used by acquire(); when None the current loop is looked
        # up on each acquire() call.
        self._loop = loop
        # Capacity of the bucket: acquire() never lets the level exceed it.
        self._max_level = max_rate
        # Amount of level that drains out of the bucket per second.
        self._rate_per_sec = max_rate / time_period
        # Current fill level.
        self._level = 0.0
        # time.monotonic() reading taken by the most recent _leak() call.
        self._last_check = 0.0
        # Waiting tasks mapped to the future used to wake them early when
        # capacity may have freed up; insertion order gives FIFO wake-ups.
        self._waiters: Dict[asyncio.Task, asyncio.Future] = collections.OrderedDict()

    def _leak(self) -> None:
        """Drip out capacity for the time elapsed since the last check."""
        # A monotonic clock cannot jump backwards (unlike time.time(), which
        # follows wall-clock adjustments); a backwards jump would make the
        # elapsed time negative and *refill* the bucket. Reading the clock
        # once also avoids losing the time between two separate readings.
        now = time.monotonic()
        if self._level:
            # drip out enough level for the elapsed time since we last checked
            elapsed = now - self._last_check
            decrement = elapsed * self._rate_per_sec
            self._level = max(self._level - decrement, 0)
        self._last_check = now

    def has_capacity(self, amount: float = 1) -> bool:
        """Check if there is enough space remaining in the bucket.

        As a side effect, wakes the first still-pending waiter (if any) when
        the request would fit strictly below the capacity.

        :param amount: how much capacity is needed.
        :returns: True if ``amount`` fits without exceeding the capacity.

        """
        self._leak()
        requested = self._level + amount
        # if there are tasks waiting for capacity, signal to the first
        # that there may be some now (they won't wake up until this task
        # yields with an await)
        if requested < self._max_level:
            for fut in self._waiters.values():
                if not fut.done():
                    fut.set_result(True)
                    break
        return requested <= self._max_level

    async def acquire(self, amount: float = 1) -> None:
        """Acquire space in the bucket.

        If the bucket is full, block until there is space.

        :param amount: how much capacity to take.
        :raises ValueError: if ``amount`` exceeds the bucket capacity.

        """
        if amount > self._max_level:
            raise ValueError("Can't acquire more than the bucket capacity")

        loop = self._loop or asyncio.get_event_loop()
        task = _current_task(loop)
        assert task is not None
        while not self.has_capacity(amount):
            # wait for the next drip to have left the bucket; register a
            # future in _waiters so has_capacity() can wake us 'early' if
            # capacity comes up in the meantime
            fut = loop.create_future()
            self._waiters[task] = fut
            try:
                # shield() so that the timeout cancels only the wrapper, not
                # fut itself (fut is cleaned up explicitly below). The
                # ``loop`` argument of wait_for() was removed in Python 3.10,
                # so it is not passed.
                await asyncio.wait_for(
                    asyncio.shield(fut),
                    1 / self._rate_per_sec * amount,
                )
            except asyncio.TimeoutError:
                pass
            finally:
                # Always deregister, including when this task is cancelled;
                # otherwise a stale future would stay in _waiters (keeping
                # the task alive) and swallow the next wake-up signal that
                # should go to a live waiter.
                fut.cancel()
                self._waiters.pop(task, None)

        self._level += amount

        return None

    async def __aenter__(self) -> None:
        """Acquire one unit of capacity on entering ``async with``."""
        await self.acquire()
        return None

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        """Do nothing on exit; capacity drains out of the bucket over time."""
        return None
106
1import asyncio
2from threading import Lock
3
class PTBNL:
    """Tracks asynchronous requests and the futures that complete them.

    Each request gets an integer id from ``get_req_id``; ``start_req``
    registers a future for it and ``end_req`` resolves that future.
    Batched requests are paced with ``self.token_bucket``.
    """

    def __init__(self):
        # Next id returned by get_req_id().
        self._req_id_seq = 0
        # Pending request futures keyed by request id.
        self._futures = {}
        # Per-request results that end_req() delivers when it is not given an
        # explicit result (nothing in this class populates it yet).
        self._results = {}
        # Paces outgoing requests in req_data_batch_async().
        self.token_bucket = TokenBucket()
        self.token_bucket.set_rate(100)

    def run(self, *awaitables):
        """Run ``awaitables`` to completion on the current event loop.

        With no arguments the loop runs forever. With one awaitable its result
        is returned; with several they are gathered and the list of their
        results is returned.
        """
        loop = asyncio.get_event_loop()

        if not awaitables:
            loop.run_forever()
            return None
        if len(awaitables) == 1:
            return loop.run_until_complete(awaitables[0])
        return loop.run_until_complete(asyncio.gather(*awaitables))

    def sleep(self, secs) -> bool:
        """Run ``asyncio.sleep(secs)`` on the current loop, then return True."""
        self.run(asyncio.sleep(secs))
        return True

    def get_req_id(self) -> int:
        """Return a new request id (0, 1, 2, ... in call order)."""
        new_id = self._req_id_seq
        self._req_id_seq += 1
        return new_id

    def start_req(self, key):
        """Create, register and return a new future for request ``key``.

        A future already registered under ``key`` is replaced.
        """
        future = asyncio.get_event_loop().create_future()
        self._futures[key] = future
        return future

    def end_req(self, key, result=None):
        """Resolve the future registered for ``key`` (no-op if there is none).

        When ``result`` is None, the value stored in ``self._results`` for
        ``key`` is used instead (an empty list if nothing was stored).
        """
        future = self._futures.pop(key, None)
        if future is None:
            return
        # Always drop the stored result so it cannot linger, even when an
        # explicit result overrides it.
        stored = self._results.pop(key, [])
        if result is None:
            result = stored
        if not future.done():
            future.set_result(result)

    def req_data(self, req_id, obj):
        """Issue the request for ``obj``.

        Placeholder: the actual work is not implemented yet, so the request is
        completed immediately via req_data_end().
        """
        # Do Some Work Here
        self.req_data_end(req_id)

    def req_data_end(self, req_id):
        """Report that ``req_id`` finished and resolve its future."""
        print(req_id, " has ended")
        self.end_req(req_id)

    async def req_data_async(self, obj):
        """Issue a single request for ``obj`` and return its result."""
        req_id = self.get_req_id()
        future = self.start_req(req_id)
        self.req_data(req_id, obj)
        return await future

    async def req_data_batch_async(self, contracts):
        """Issue one paced request per contract and wait for all of them.

        Each request is scheduled with the delay returned by
        ``self.token_bucket.consume(1)``.

        Returns ``(futures, rate)``: the completed request futures in contract
        order, and the achieved throughput in requests per second of
        event-loop time. ``rate`` is 0.0 for an empty batch and ``inf`` if the
        loop clock did not advance (coarse clock resolution).
        """
        loop = asyncio.get_event_loop()
        futures = []
        start = loop.time()

        for contract in contracts:
            req_id = self.get_req_id()
            futures.append(self.start_req(req_id))
            # Delay (seconds) the token bucket asks us to wait before sending.
            nap = self.token_bucket.consume(1)
            loop.call_later(nap, self.req_data, req_id, contract)

        if not futures:
            # Nothing was requested; there is no meaningful rate to compute.
            return futures, 0.0

        await asyncio.gather(*futures)
        elapsed = loop.time() - start
        if elapsed <= 0:
            return futures, float("inf")
        return futures, len(futures) / elapsed
93
class TokenBucket:
    """Token bucket that tells callers how long to wait before proceeding.

    ``consume`` never blocks; it returns a delay in seconds (0 when enough
    tokens are available). The bucket holds at most ``rate`` tokens, i.e.
    one second's worth of refill. State is guarded by a ``threading.Lock``.
    """

    def __init__(self):
        # Current token balance; may go negative while requests are queued.
        self.tokens = 0
        # Refill rate in tokens per second; 0 disables limiting.
        self.rate = 0
        # Use a monotonic clock directly (the same kind of clock asyncio's
        # loop.time() is based on) rather than asyncio.get_event_loop().time():
        # that lookup fails in threads without a current event loop and is
        # deprecated outside a running loop, while this class is meant to be
        # shared across threads (it is guarded by a threading.Lock).
        self.last = time.monotonic()
        self.lock = Lock()

    def set_rate(self, rate):
        """Set the refill rate (tokens per second) and fill the bucket."""
        with self.lock:
            self.rate = rate
            self.tokens = self.rate

    def consume(self, tokens):
        """Take ``tokens`` from the bucket and return the wait in seconds.

        The bucket refills at ``rate`` tokens per second up to ``rate`` tokens.
        The balance may go negative; the returned delay is the time needed to
        refill that deficit. Returns 0 when the rate is 0 (limiting disabled).
        """
        with self.lock:
            if not self.rate:
                return 0

            now = time.monotonic()
            elapsed = now - self.last
            self.last = now
            # Refill for the elapsed time, capped at one second's worth.
            self.tokens = min(self.tokens + elapsed * self.rate, self.rate)
            self.tokens -= tokens

            if self.tokens >= 0:
                return 0
            return -self.tokens / self.rate
126
127
if __name__ == '__main__':
    # Create and install the loop explicitly: asyncio.get_event_loop() with no
    # current loop is deprecated and fails on recent Python versions, while
    # PTBNL.run() looks the loop up via asyncio.get_event_loop(), so it must
    # be set as the current loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.set_debug(True)
    try:
        app = PTBNL()

        objs = list(range(500))

        futures, rate = app.run(app.req_data_batch_async(objs))

        print(futures)
        print(rate)
    finally:
        asyncio.set_event_loop(None)
        loop.close()
139