We started out aiming to fine-tune Llama 3 on TPUs, but PyTorch XLA was clunky, so we decided to rewrite the model in JAX. That said, as mentioned earlier in the thread, we also believe JAX is a better platform for non-NVIDIA GPUs, and we want to build our infrastructure for non-NVIDIA GPUs on JAX + OpenXLA.