Skip to content

Instantly share code, notes, and snippets.

@jcrist

jcrist/bench.py Secret

Last active February 10, 2022 08:14
Show Gist options
  • Save jcrist/c6336718edaabde21f2e1c269ac2bc88 to your computer and use it in GitHub Desktop.
Save jcrist/c6336718edaabde21f2e1c269ac2bc88 to your computer and use it in GitHub Desktop.
Dask comms benchmark
import argparse
import asyncio
import datetime
import os
import threading
import time
import timeit
from concurrent.futures import ProcessPoolExecutor
from distributed.comm import listen, connect, CommClosedError
from distributed.protocol.serialize import Serialized
from dask.utils import parse_timedelta, format_time
# Absolute path of the directory that contains this script.
DIR = os.path.abspath(os.path.dirname(__file__))
# PEM certificate and private key used for TLS runs; both live next to the script.
CERT = os.path.join(DIR, "bench-cert.pem")
KEY = os.path.join(DIR, "bench-key.pem")
class GILHolder(threading.Thread):
    """Daemon thread that keeps the interpreter busy with pure-Python work.

    Two modes, selected by ``off_time``:

    * falsy (``0`` or ``None``): spin in a tight loop forever, never sleeping.
    * truthy: alternate between roughly ``on_time`` seconds of
      ``sum(range(1000))`` calls and a ``time.sleep(off_time)``.

    Parameters
    ----------
    on_time
        Approximate length, in seconds, of each busy burst.
    off_time
        Length, in seconds, of the sleep between bursts.
    """

    def __init__(self, on_time, off_time=0):
        super().__init__(daemon=True, name="GIL-holder")
        self.on_time = on_time
        self.off_time = off_time
        # Average cost of one work unit. Only the duty-cycle mode uses it, so
        # skip the (comparatively slow) calibration when we will just spin.
        self.time_once = (
            timeit.timeit("sum(range(1000))", number=10_000) / 10_000
            if self.off_time
            else None
        )

    def _hold_iterations(self):
        """Return the number of work units in one busy burst (always >= 1)."""
        # Clamp to 1: a very small on_time would otherwise round down to zero
        # iterations and the thread would only ever sleep.
        return max(1, round(self.on_time / self.time_once))

    def run(self):
        """Thread body; never returns."""
        if not self.off_time:
            while True:
                pass

        n_iter = self._hold_iterations()
        while True:
            # Stay busy for ~on_time
            for _ in range(n_iter):
                sum(range(1000))
            # Back off for ~off_time
            time.sleep(self.off_time)
def ensure_certs():
    """Create the self-signed TLS certificate and key if either is missing.

    Does nothing when both ``KEY`` and ``CERT`` already exist. Otherwise a
    new 2048-bit RSA key and a self-signed certificate (common name
    ``ery-bench``, valid for 365 days from now) are generated and written in
    PEM format to ``CERT`` and ``KEY``.

    The ``cryptography`` imports are local to this function, so the package
    is only needed when this function actually has to generate files.
    """
    if os.path.exists(KEY) and os.path.exists(CERT):
        return

    from cryptography import x509
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from cryptography.x509.oid import NameOID

    key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    key_bytes = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # Self-signed: subject and issuer are the same name.
    subject = issuer = x509.Name(
        [x509.NameAttribute(NameOID.COMMON_NAME, "ery-bench")]
    )
    # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated.
    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=365))
        .sign(key, hashes.SHA256(), default_backend())
    )
    cert_bytes = cert.public_bytes(serialization.Encoding.PEM)

    with open(CERT, "wb") as f:
        f.write(cert_bytes)
    # The key is unencrypted, so request owner-only permissions when the file
    # is created instead of relying on the process umask.
    key_fd = os.open(KEY, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(key_fd, "wb") as f:
        f.write(key_bytes)
def get_ssl_context():
    """Return an ``ssl.SSLContext`` loaded with the benchmark cert and key.

    Hostname checking and peer-certificate verification are both disabled.
    """
    import ssl as _ssl

    ctx = _ssl.SSLContext(_ssl.PROTOCOL_TLS)
    ctx.load_cert_chain(certfile=CERT, keyfile=KEY)
    # Hostname checking is switched off before verification is dropped.
    ctx.check_hostname = False
    ctx.verify_mode = _ssl.CERT_NONE
    return ctx
class BenchProc:
    """Client side of the benchmark.

    Opens ``n_clients`` concurrent connections to ``127.0.0.1:8080`` and has
    each one repeatedly write a message and read the reply until
    ``n_seconds`` have passed.

    Parameters
    ----------
    n_bytes
        Size in bytes of each frame.
    n_frames
        Number of frames in the message.
    n_seconds
        How long to keep sending, in seconds.
    n_clients
        Number of concurrent client connections.
    use_tls
        Connect with ``tls://`` (and an SSL context) instead of ``tcp://``.
    """

    def __init__(self, n_bytes, n_frames, n_seconds, n_clients, use_tls):
        self.n_bytes = n_bytes
        self.n_frames = n_frames
        self.n_seconds = n_seconds
        self.n_clients = n_clients
        self.use_tls = use_tls
        # Defined up front so the flag can be read before run() is called.
        self.running = False

    async def client(self, address, comm_kwargs, msg):
        """Write ``msg`` and read the reply in a loop while ``self.running``."""
        comm = await connect(address, deserialize=False, **comm_kwargs)
        try:
            while self.running:
                await comm.write(msg)
                await comm.read()
        finally:
            # Close the comm even if a write/read failed mid-loop.
            await comm.close()

    def stop(self):
        """Ask every client loop to finish after its current round trip."""
        self.running = False

    async def run(self):
        """Run all clients concurrently for ``n_seconds`` and wait for them."""
        self.running = True
        if self.use_tls:
            kwargs = {"ssl_context": get_ssl_context()}
            prefix = "tls"
        else:
            kwargs = {}
            prefix = "tcp"
        address = f"{prefix}://127.0.0.1:8080"

        # One random payload, shared by every client: n_frames frames of
        # n_bytes each.
        msg = Serialized({}, [os.urandom(self.n_bytes) for _ in range(self.n_frames)])

        loop = asyncio.get_running_loop()
        loop.call_later(self.n_seconds, self.stop)

        tasks = [
            asyncio.create_task(self.client(address, kwargs, msg))
            for _ in range(self.n_clients)
        ]
        # return_exceptions=True: a failing client does not cancel the others.
        # Note that client errors are collected here and not reported.
        await asyncio.gather(*tasks, return_exceptions=True)
def bench_proc_main(n_bytes, n_frames, n_seconds, n_clients, use_tls):
    """Build a ``BenchProc`` from the arguments and run it under ``asyncio.run``."""
    proc = BenchProc(
        n_bytes=n_bytes,
        n_frames=n_frames,
        n_seconds=n_seconds,
        n_clients=n_clients,
        use_tls=use_tls,
    )
    asyncio.run(proc.run())
async def run(n_bytes, n_frames, n_seconds, n_procs, n_clients, use_tls):
    """Run the echo server, drive it from client processes, print the results.

    A listener on ``127.0.0.1:8080`` echoes every message it reads.
    ``n_procs`` worker processes each execute ``bench_proc_main``. The
    measured window opens when the last of the ``n_procs * n_clients``
    expected connections arrives and closes at the first ``CommClosedError``.

    Parameters
    ----------
    n_bytes
        Size in bytes of each frame.
    n_frames
        Number of frames per message.
    n_seconds
        How long each client process should keep sending.
    n_procs
        Number of client processes.
    n_clients
        Number of concurrent clients per process.
    use_tls
        Use ``tls://`` instead of ``tcp://``; certificates are created on
        demand.
    """
    if use_tls:
        ensure_certs()
        kwargs = {"ssl_context": get_ssl_context()}
        prefix = "tls"
    else:
        kwargs = {}
        prefix = "tcp"
    address = f"{prefix}://127.0.0.1:8080"
    loop = asyncio.get_running_loop()

    pending = n_procs * n_clients  # connections still expected
    live_count = 0  # echoes served, across all connections
    count = 0  # snapshot of live_count when the window closes
    start_time = None
    stop_time = None

    async def handle_comm(comm):
        nonlocal pending, live_count, count, start_time, stop_time
        pending -= 1
        if not pending:
            # All connected: open the window and drop warm-up echoes.
            # perf_counter is monotonic, unlike wall-clock time.time().
            start_time = time.perf_counter()
            live_count = 0
        try:
            while True:
                msg = await comm.read()
                await comm.write(msg=msg)
                live_count += 1
        except CommClosedError:
            if stop_time is None:
                # First disconnect: close the window and freeze the count.
                stop_time = time.perf_counter()
                count = live_count

    async with listen(address, handle_comm, deserialize=False, **kwargs):
        with ProcessPoolExecutor(max_workers=n_procs) as executor:
            tasks = [
                loop.run_in_executor(
                    executor,
                    bench_proc_main,
                    n_bytes,
                    n_frames,
                    n_seconds,
                    n_clients,
                    use_tls,
                )
                for _ in range(n_procs)
            ]
            await asyncio.gather(*tasks)

    # Without a valid window the divisions below would raise
    # ZeroDivisionError or print meaningless negative rates.
    if start_time is None or stop_time is None or count <= 0:
        print("No requests were measured (did the clients fail to connect?)")
        return
    elapsed = stop_time - start_time
    if elapsed <= 0:
        print("No requests were measured (did the clients fail to connect?)")
        return

    print(f"{count / elapsed} RPS")
    print(f"{elapsed / count * 1e6} us per request")
    # Every message carries n_frames frames of n_bytes each.
    print(f"{n_bytes * n_frames * count / (elapsed * 1e6)} MB/s each way")
def main(argv=None):
    """Parse the command line and run the benchmark.

    Parameters
    ----------
    argv
        Optional list of argument strings. ``None`` (the default) makes
        argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Benchmark channels")
    parser.add_argument(
        "--procs", "-p", default=1, type=int, help="Number of client processes"
    )
    parser.add_argument(
        "--concurrency",
        "-c",
        default=1,
        type=int,
        help="Number of clients per process",
    )
    parser.add_argument(
        "--frames",
        "-f",
        default=1,
        type=int,
        help="Number of frames per message",
    )
    # Parsed as float so values such as 1e6 are accepted; truncated to int
    # below. Each of the --frames frames is this many bytes.
    parser.add_argument(
        "--bytes", "-b", default=1000, type=float, help="size of each frame in bytes"
    )
    parser.add_argument(
        "--seconds", "-s", default=5, type=parse_timedelta, help="bench duration in secs"
    )
    parser.add_argument(
        "--gil-hold-time",
        default=0,
        type=parse_timedelta,
        help="how long a background thread should hold the GIL",
    )
    parser.add_argument(
        "--gil-release-time",
        default=0,
        type=parse_timedelta,
        help="how long a background thread should release the GIL",
    )
    parser.add_argument("--uvloop", action="store_true", help="Whether to use uvloop")
    parser.add_argument("--tls", action="store_true", help="Whether to use TLS")
    args = parser.parse_args(argv)

    if args.uvloop:
        import uvloop

        uvloop.install()

    n_bytes = int(args.bytes)
    print(
        f"processes = {args.procs}, "
        f"concurrency = {args.concurrency}, "
        f"bytes = {n_bytes}, "
        f"frames = {args.frames}, "
        f"seconds = {args.seconds}, "
        f"uvloop = {args.uvloop}, "
        f"tls = {args.tls}"
    )

    if args.gil_hold_time:
        print(
            f"gil_hold_time = {format_time(args.gil_hold_time)}, "
            f"gil_release_time = {format_time(args.gil_release_time)}"
        )
        # Started in this process, i.e. the one that runs the echo server.
        GILHolder(args.gil_hold_time, args.gil_release_time).start()

    asyncio.run(
        run(n_bytes, args.frames, args.seconds, args.procs, args.concurrency, args.tls)
    )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment