@maheshambule
Created October 22, 2020 15:27
Testing GIL load for Pytorch, numpy and BERT
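"""Benchmark GIL contention for NumPy, PyTorch preprocessing and BERT inference.

Runs a chosen workload under a ThreadPoolExecutor or a ProcessPoolExecutor
while the gil_load package samples how often the GIL is held, so that
thread-based and process-based serving can be compared.
"""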
import sys
import threading
import traceback
import concurrent.futures as futures
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

import click
import gil_load
import numpy as np
from pynvml import *  # NVML bindings; only referenced by the commented-out GPU calls below
N_THREADS = 4
NPTS = 4096
image_link = "https://upload.wikimedia.org/wikipedia/commons/f/ff/Pizigani_1367_Chart_10MB.jpg"
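# Source URL for the test image. Note that the script never downloads it:
# Image.open() below reads the file from the working directory, so it is
# assumed to have been fetched beforehand (e.g. with wget or curl).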
def numpy_preprocess():
    print("running numpy preprocess")
    for i in range(2):
        x = np.random.randn(NPTS, NPTS)
        # NumPy's pocketfft-based FFT is expected to release the GIL while
        # computing, so this loop should overlap well across threads.
        x[:] = np.fft.fft2(x).real
def torch_preprocess(image_name="Pizigani_1367_Chart_10MB.jpg", batch_size=10):
    try:
        print(f"running torch preprocess with image_name={image_name} and batch_size={batch_size}")
        import torch
        from PIL import Image
        from torchvision import transforms

        images = []
        for i in range(batch_size):
            image = Image.open(image_name)
            image_processing = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
            im = image_processing(image)
            images.append(im)
        torch.stack(images)
    except Exception as e:
        traceback.print_exc(file=sys.stdout)
        print(e)
    print("done torch preprocess")
def predict(compute_unit=0, model_name='bert-base-uncased'):
    try:
        from transformers import BertModel, BertTokenizer
        import torch
        import os

        # Model load
        model = BertModel.from_pretrained(model_name)
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda:0" if use_cuda else "cpu")
        if use_cuda:
            model.to(device)
        print(f"Loaded model {model_name} in GPU#{compute_unit} - {os.getpid()}:{threading.current_thread().ident}")

        # Inference
        for i in range(10):
            tokenizer = BertTokenizer.from_pretrained(model_name)
            text = "Replace me by any text you'd like."
            encoded_input = tokenizer(text, return_tensors='pt')
            if use_cuda:
                encoded_input = encoded_input.to(device)
            model(**encoded_input)
            # dump_device_info(0)
            # cpu_intensive_method()
            print(f"Inference number {i}")
    except Exception as e:
        traceback.print_exc(file=sys.stdout)
        print(e)
def preprocess_predict(image_name="Pizigani_1367_Chart_10MB.jpg", batch_size=10, compute_unit=0):
    try:
        torch_preprocess(image_name=image_name, batch_size=batch_size)
        predict(compute_unit=compute_unit)
    except Exception as e:
        traceback.print_exc(file=sys.stdout)
        print(e)
@click.command()
@click.option('--server', default="thread", type=str)
@click.option('--instances', default=5, type=int)
@click.option('--image_name', default="Pizigani_1367_Chart_10MB.jpg", type=str)
@click.option('--batch_size', default=10, type=int)
@click.option('--target_method', default="torch_preprocess", type=str)
def run_benchmark(server, instances, image_name, batch_size, target_method):
    gil_load.init()
    gil_load.start()
    print(f"=======server={server},instances={instances},target_method={target_method}, image_name={image_name}, batch_size={batch_size}===================")
    wait_for = []
    executor = ThreadPoolExecutor(max_workers=instances) if server == "thread" \
        else ProcessPoolExecutor(max_workers=instances)
    with executor as e:
        for compute_unit in range(0, instances):
            if target_method == 'numpy_preprocess':
                wait_for.append(e.submit(numpy_preprocess))
            elif target_method == 'torch_preprocess':
                wait_for.append(e.submit(torch_preprocess, image_name, batch_size))
            elif target_method == 'predict':
                wait_for.append(e.submit(predict, compute_unit))
            elif target_method == 'preprocess_predict':
                wait_for.append(e.submit(preprocess_predict, image_name, batch_size, compute_unit))
        for f in futures.as_completed(wait_for):
            print("printing")
            print('main: result: {}'.format(f.result()))
    gil_load.stop()
    stats = gil_load.get()
    print(gil_load.format(stats))
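    # gil_load.format(stats) summarizes the sampled statistics: per the
    # gil_load README, the fraction of time the GIL was held (and, per
    # thread, waited on), averaged over recent intervals.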
# gil_load.start()
#
# threads = []
# for i in range(N_THREADS):
#     thread = threading.Thread(target=predict, daemon=True)
#     threads.append(thread)
#     thread.start()
#
# for thread in threads:
#     thread.join()
#
# gil_load.stop()
#
# stats = gil_load.get()
# print(gil_load.format(stats))
if __name__ == "__main__":
run_benchmark()
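
# Example invocations (a sketch; assumes this gist is saved as gil_load_test.py
# and that numpy, torch, torchvision, transformers, click, pynvml, gil_load and
# the test image are available locally):
#
#   python gil_load_test.py --server thread  --instances 4 --target_method numpy_preprocess
#   python gil_load_test.py --server thread  --instances 4 --target_method torch_preprocess --batch_size 10
#   python gil_load_test.py --server process --instances 4 --target_method predict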