# Install a CUDA-enabled JAX 0.4.2 stack.
#
# jax (the Python frontend) and jaxlib (the compiled backend) are pinned to the
# same version. The "+cuda11.cudnn86" local-version tag selects the jaxlib wheel
# variant built for CUDA 11 / cuDNN 8.6 (per the tag name — confirm it matches
# the host's installed CUDA/cuDNN). Those CUDA wheels are listed on Google's
# jax_cuda_releases page, so pip is given that page via -f (--find-links).
#
# Both packages go through a single pip invocation so the resolver checks the
# pinned pair together (previously each was installed separately, and the whole
# pair was installed twice). Abort if the install fails, e.g. when no wheel
# matches this Python version/platform, instead of continuing without JAX.
pip install \
  "jax==0.4.2" \
  "jaxlib==0.4.2+cuda11.cudnn86" \
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
  || { printf 'error: failed to install jax/jaxlib 0.4.2 (cuda11.cudnn86)\n' >&2; exit 1; }