-
-
Save jcrist/c6336718edaabde21f2e1c269ac2bc88 to your computer and use it in GitHub Desktop.
Dask comms benchmark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import asyncio | |
import datetime | |
import os | |
import threading | |
import time | |
import timeit | |
from concurrent.futures import ProcessPoolExecutor | |
from distributed.comm import listen, connect, CommClosedError | |
from distributed.protocol.serialize import Serialized | |
from dask.utils import parse_timedelta, format_time | |
# Directory containing this script; the TLS benchmark material lives beside it.
DIR = os.path.dirname(os.path.abspath(__file__))
# Self-signed certificate and private key used when benchmarking over TLS.
CERT, KEY = (os.path.join(DIR, name) for name in ("bench-cert.pem", "bench-key.pem"))
class GILHolder(threading.Thread):
    """Daemon thread that keeps the interpreter busy to simulate GIL contention.

    Parameters
    ----------
    on_time
        Approximate number of seconds of pure-Python work per cycle.
        Only used when ``off_time`` is truthy.
    off_time
        Seconds to sleep between bursts of work.  A falsy value (``0`` or
        ``None``) makes the thread spin forever without ever sleeping.

    Attributes
    ----------
    time_once
        Measured cost in seconds of one ``sum(range(1000))`` call, or
        ``None`` when the thread spins forever and needs no calibration.
    """

    def __init__(self, on_time, off_time=0):
        super().__init__(daemon=True, name="GIL-holder")
        self.on_time = on_time
        self.off_time = off_time
        # Calibration is only consumed by the duty-cycle branch of ``run``,
        # which is taken when ``off_time`` is truthy.  Testing truthiness here
        # (rather than ``is not None``) skips a needless 10,000-iteration
        # timing run for the default ``off_time=0``.
        self.time_once = (
            timeit.timeit("sum(range(1000))", number=10_000) / 10_000
            if self.off_time
            else None
        )

    def run(self):
        """Spin forever, or alternate busy bursts with sleeps of ``off_time``."""
        if not self.off_time:
            while True:
                pass

        # Number of unit workloads that add up to roughly ``on_time`` seconds.
        n_iters = round(self.on_time / self.time_once)
        while True:
            # Hold GIL for ~on_time
            for _ in range(n_iters):
                sum(range(1000))
            # Release GIL for ~off_time
            time.sleep(self.off_time)
def ensure_certs():
    """Generate the benchmark TLS key and certificate unless both already exist.

    When either ``KEY`` or ``CERT`` is missing, a fresh 2048-bit RSA key and a
    self-signed certificate (subject == issuer, valid for 365 days from now)
    are created and written to ``CERT`` and ``KEY`` in PEM format.  Requires
    the ``cryptography`` package, which is imported lazily so that plain TCP
    runs do not need it.
    """
    if os.path.exists(KEY) and os.path.exists(CERT):
        return

    from cryptography import x509
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from cryptography.x509.oid import NameOID

    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    key_bytes = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Same name for subject and issuer: the certificate is self-signed.
    subject = issuer = x509.Name(
        [x509.NameAttribute(NameOID.COMMON_NAME, "ery-bench")]
    )
    # Timezone-aware UTC timestamp; ``datetime.utcnow()`` is deprecated since
    # Python 3.12.
    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=365))
        .sign(key, hashes.SHA256(), default_backend())
    )
    cert_bytes = cert.public_bytes(serialization.Encoding.PEM)

    with open(CERT, "wb") as f:
        f.write(cert_bytes)
    with open(KEY, "wb") as f:
        f.write(key_bytes)
def get_ssl_context():
    """Return an SSL context carrying the benchmark cert with peer checks disabled.

    The context loads ``CERT``/``KEY`` and performs neither hostname checking
    nor certificate verification.
    """
    import ssl

    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
    ctx.load_cert_chain(certfile=CERT, keyfile=KEY)
    # Hostname checking is switched off before verification is relaxed.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
class BenchProc:
    """Load generator that runs inside a single client process.

    Parameters
    ----------
    n_bytes
        Size in bytes of each frame of the message.
    n_frames
        Number of frames per message.
    n_seconds
        How long the clients keep sending before being told to stop.
    n_clients
        Number of concurrent client connections opened by this process.
    use_tls
        Connect over ``tls://`` (with the benchmark SSL context) instead of
        ``tcp://``.
    """

    def __init__(self, n_bytes, n_frames, n_seconds, n_clients, use_tls):
        self.n_bytes = n_bytes
        self.n_frames = n_frames
        self.n_seconds = n_seconds
        self.n_clients = n_clients
        self.use_tls = use_tls
        # Defined up front so ``stop()`` and ``client()`` are well-behaved
        # even if they are reached before ``run()`` flips the flag.
        self.running = False

    async def client(self, address, comm_kwargs, msg):
        """Connect to ``address`` and ping-pong ``msg`` until ``stop()`` is called."""
        comm = await connect(address, deserialize=False, **comm_kwargs)
        try:
            while self.running:
                await comm.write(msg)
                await comm.read()
        finally:
            # Always release the connection, including when a write/read fails.
            await comm.close()

    def stop(self):
        """Ask every client loop to exit after its current round trip."""
        self.running = False

    async def run(self):
        """Run ``n_clients`` concurrent clients for about ``n_seconds`` seconds."""
        self.running = True
        if self.use_tls:
            kwargs = {"ssl_context": get_ssl_context()}
            prefix = "tls"
        else:
            kwargs = {}
            prefix = "tcp"
        address = f"{prefix}://127.0.0.1:8080"

        # One message is built once and shared by all clients: ``n_frames``
        # frames of ``n_bytes`` random bytes each.
        msg = Serialized({}, [os.urandom(self.n_bytes) for _ in range(self.n_frames)])

        loop = asyncio.get_running_loop()
        loop.call_later(self.n_seconds, self.stop)

        tasks = [
            asyncio.create_task(self.client(address, kwargs, msg))
            for _ in range(self.n_clients)
        ]
        # Client failures are collected rather than raised so that one broken
        # connection does not abort the remaining clients.
        await asyncio.gather(*tasks, return_exceptions=True)
def bench_proc_main(n_bytes, n_frames, n_seconds, n_clients, use_tls):
    """Worker-process entry point: build a ``BenchProc`` and run it to completion.

    Runs the load generator in a fresh event loop via ``asyncio.run``.
    """
    proc = BenchProc(
        n_bytes=n_bytes,
        n_frames=n_frames,
        n_seconds=n_seconds,
        n_clients=n_clients,
        use_tls=use_tls,
    )
    asyncio.run(proc.run())
async def run(n_bytes, n_frames, n_seconds, n_procs, n_clients, use_tls):
    """Run the echo server, drive it from client processes, and print throughput.

    Parameters
    ----------
    n_bytes
        Size in bytes of each frame sent by the clients.
    n_frames
        Number of frames per message.
    n_seconds
        How long each client process keeps sending.
    n_procs
        Number of client processes to spawn.
    n_clients
        Number of concurrent connections per client process.
    use_tls
        Serve over ``tls://`` (creating the benchmark certificate if needed)
        instead of ``tcp://``.

    The measured interval starts when the last expected client has connected
    and ends when the first client disconnects.  Requests per second,
    microseconds per request and MB/s are printed for that interval.  If no
    usable measurement was taken, a short notice is printed instead.
    """
    if use_tls:
        ensure_certs()
        kwargs = {"ssl_context": get_ssl_context()}
        prefix = "tls"
    else:
        kwargs = {}
        prefix = "tcp"
    address = f"{prefix}://127.0.0.1:8080"
    loop = asyncio.get_running_loop()

    connections = n_procs * n_clients  # clients still expected to connect
    count_var = 0  # live counter of echoed messages
    count = 0  # snapshot of count_var at the first disconnect
    start_time = None
    stop_time = None

    async def handle_comm(comm):
        """Echo every message received on ``comm`` until the peer disconnects."""
        nonlocal connections
        nonlocal count_var
        nonlocal count
        nonlocal start_time
        nonlocal stop_time

        connections -= 1
        if not connections:
            # All connected, record time and restart counter
            start_time = time.perf_counter()
            count_var = 0

        try:
            while True:
                msg = await comm.read()
                await comm.write(msg=msg)
                count_var += 1
        except CommClosedError:
            if stop_time is None:
                # First disconnected, stash count and record time
                stop_time = time.perf_counter()
                count = count_var

    async with listen(address, handle_comm, deserialize=False, **kwargs):
        with ProcessPoolExecutor(max_workers=n_procs) as executor:
            tasks = [
                loop.run_in_executor(
                    executor,
                    bench_proc_main,
                    n_bytes,
                    n_frames,
                    n_seconds,
                    n_clients,
                    use_tls,
                )
                for _ in range(n_procs)
            ]
            await asyncio.gather(*tasks)

    # Guard against runs where clients never all connected, never
    # disconnected, or completed no round trips: the divisions below would
    # otherwise raise ZeroDivisionError or print meaningless numbers.
    if (
        start_time is None
        or stop_time is None
        or stop_time <= start_time
        or count <= 0
    ):
        print("No completed requests were measured; nothing to report")
        return

    elapsed = stop_time - start_time
    print(f"{count / elapsed} RPS")
    print(f"{elapsed / count * 1e6} us per request")
    # Each message carries n_frames frames of n_bytes bytes each.
    print(f"{n_bytes * n_frames * count / (elapsed * 1e6)} MB/s each way")
def main():
    """Command-line entry point: parse options, echo the config, run the benchmark.

    Optionally installs uvloop and starts a background ``GILHolder`` thread
    before handing over to ``run``.
    """
    parser = argparse.ArgumentParser(description="Benchmark channels")

    # (flags, add_argument keyword options), registered in display order.
    option_table = [
        (
            ("--procs", "-p"),
            {"default": 1, "type": int, "help": "Number of client processes"},
        ),
        (
            ("--concurrency", "-c"),
            {"default": 1, "type": int, "help": "Number of clients per process"},
        ),
        (
            ("--frames", "-f"),
            {"default": 1, "type": int, "help": "Number of frames per message"},
        ),
        (
            ("--bytes", "-b"),
            {"default": 1000, "type": float, "help": "total payload size in bytes"},
        ),
        (
            ("--seconds", "-s"),
            {"default": 5, "type": parse_timedelta, "help": "bench duration in secs"},
        ),
        (
            ("--gil-hold-time",),
            {
                "default": 0,
                "type": parse_timedelta,
                "help": "how long a background thread should hold the GIL",
            },
        ),
        (
            ("--gil-release-time",),
            {
                "default": 0,
                "type": parse_timedelta,
                "help": "how long a background thread should release the GIL",
            },
        ),
        (
            ("--uvloop",),
            {"action": "store_true", "help": "Whether to use uvloop"},
        ),
        (
            ("--tls",),
            {"action": "store_true", "help": "Whether to use TLS"},
        ),
    ]
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()

    if args.uvloop:
        import uvloop

        uvloop.install()

    # --bytes is parsed as a float and truncated to an int here.
    n_bytes = int(args.bytes)

    print(
        f"processes = {args.procs}, "
        f"concurrency = {args.concurrency}, "
        f"bytes = {n_bytes}, "
        f"frames = {args.frames}, "
        f"seconds = {args.seconds}, "
        f"uvloop = {args.uvloop}, "
        f"tls = {args.tls}"
    )

    hold_time = args.gil_hold_time
    if hold_time:
        release_time = args.gil_release_time
        print(
            f"gil_hold_time = {format_time(hold_time)}, "
            f"gil_release_time = {format_time(release_time)}"
        )
        holder = GILHolder(hold_time, release_time)
        holder.start()

    benchmark = run(
        n_bytes, args.frames, args.seconds, args.procs, args.concurrency, args.tls
    )
    asyncio.run(benchmark)
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment