# TPU support

Lit-LLaMA uses `lightning.Fabric` under the hood, which itself supports TPUs (via [PyTorch XLA](https://github.com/pytorch/xla)).

The following commands set up a Google Cloud instance with a [TPU v4](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm) VM:

```shell
gcloud compute tpus tpu-vm create lit-llama --version=tpu-vm-v4-pt-2.0 --accelerator-type=v4-8 --zone=us-central2-b
gcloud compute tpus tpu-vm ssh lit-llama --zone=us-central2-b
```

Now that you are on the machine, clone the repository and install the dependencies:

```shell
git clone https://github.com/Lightning-AI/lit-llama
cd lit-llama
pip install -r requirements.txt
```

By default, computations will run using the new (and experimental) PjRT runtime. Still, it's recommended that you set the following environment variables:

```shell
export PJRT_DEVICE=TPU
export ALLOW_MULTIPLE_LIBTPU_LOAD=1
```
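
To confirm that both variables are actually exported in the current shell before launching a script, a small check along these lines can help. This is only a sketch: the `${VAR:-default}` fallback is a convenience added here, not something the repository requires.

```shell
# Keep any non-empty value already present, otherwise fall back to the values above
export PJRT_DEVICE="${PJRT_DEVICE:-TPU}"
export ALLOW_MULTIPLE_LIBTPU_LOAD="${ALLOW_MULTIPLE_LIBTPU_LOAD:-1}"

# Print what child processes such as python3 will inherit
env | grep -E '^(PJRT_DEVICE|ALLOW_MULTIPLE_LIBTPU_LOAD)='
```

If either variable is missing from the output, it was set without `export` and will not be visible to `python3`.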

> **Note**
> For an extensive guide on getting set up, including all the available options, see the [Google Cloud TPU v4 user guide](https://cloud.google.com/tpu/docs/v4-users-guide).

Since you created a new machine, you'll probably need to download the weights. You can either copy them onto the machine with `gcloud compute tpus tpu-vm scp` or follow the steps described in our [downloading guide](download_weights.md).
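
As a rough sketch of the `scp` route, the snippet below assembles the copy command and prints it for review. It assumes you run it from your local machine (not the TPU VM), that the weights sit in a local directory named `checkpoints` (a hypothetical path), and that the repository was cloned into `lit-llama` in the remote home directory. The variable names and the `RUN` switch are illustrative only.

```shell
# Assumed values: adjust to your setup. LOCAL_WEIGHTS is a hypothetical local path.
TPU_NAME="lit-llama"
ZONE="us-central2-b"
LOCAL_WEIGHTS="checkpoints"
REMOTE_DIR="lit-llama/"   # relative to the remote home directory

# Assemble the copy command; --recurse copies the whole directory tree
SCP_CMD="gcloud compute tpus tpu-vm scp --recurse $LOCAL_WEIGHTS $TPU_NAME:$REMOTE_DIR --zone=$ZONE"

# Print the command for review; set RUN=1 to actually execute it
echo "$SCP_CMD"
if [ "${RUN:-0}" = "1" ]; then
    eval "$SCP_CMD"
fi
```

Printing first makes it easy to double-check the instance name and zone before any data is transferred.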

## Inference

Generation works out of the box with TPUs:

```shell
python3 generate.py --prompt "Hello, my name is" --num_samples 3
```

This command will take ~20s for the first generation, as XLA needs to compile the graph.
Afterwards, generation times drop to ~5s.

## Finetuning

Coming soon.

> **Warning**
> When you are done, remember to delete your instance:
> ```shell
> gcloud compute tpus tpu-vm delete lit-llama --zone=us-central2-b
> ```