JAX/Flax Implementation

#13, opened by lemon-mint

DeepMind's Gemma implementation does not seem to have been updated in accordance with the new release.

Are there any plans to release the JAX/Flax implementation and model?

Google org

There is! Our focus was on getting the weights out properly. Out of my own curiosity, why are you interested in Flax/JAX in particular?

I think using TPUs is the most cost-effective way to fully fine-tune the 27B model.

Additionally, the JAX/Flax implementation is good to use as a reference implementation. Last time, in Gemma 1, DeepMind's implementation was the only one without bugs.

@canyon289 This would be very convenient. I want to integrate it with our JORA library (JAX-centered LLM PEFT fine-tuning). I believe the only differences from Gemma 1/1.1 are:

  • Logit softcaps,
  • Sliding Window Attention, and
  • query normalization

Plus, the weights in Flax format (i.e., orbax.checkpoint).
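For reference, here is a minimal JAX sketch of the first two items in that list, assuming their standard formulations: logit soft-capping as `cap * tanh(logits / cap)`, and a causal mask limited to the most recent `window` positions. The helper names, the `query_scale` multiplier (an assumed stand-in for the query normalization step) and the example values are hypothetical; this is not DeepMind's `gemma` code.

```python
import jax
import jax.numpy as jnp


def softcap(logits, cap):
    # Smoothly squash logits into (-cap, cap); near zero this is ~identity.
    return cap * jnp.tanh(logits / cap)


def sliding_window_causal_mask(seq_len, window):
    # True where query position i may attend to key position j:
    # causal (j <= i) and within the local window (i - j < window).
    q_pos = jnp.arange(seq_len)[:, None]
    k_pos = jnp.arange(seq_len)[None, :]
    return (k_pos <= q_pos) & (q_pos - k_pos < window)


def attention_weights(q, k, window, cap, query_scale):
    # Hypothetical single-head helper; q, k: [seq_len, head_dim].
    # query_scale is an assumed stand-in for the query normalization step.
    logits = (q * query_scale) @ k.T
    logits = softcap(logits, cap)  # soft-cap before masking
    mask = sliding_window_causal_mask(q.shape[0], window)
    logits = jnp.where(mask, logits, -1e30)  # masked keys get ~zero weight
    return jax.nn.softmax(logits, axis=-1)


# Illustrative usage with random inputs (all values are arbitrary).
key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (6, 8))
k = jax.random.normal(key_k, (6, 8))
w = attention_weights(q, k, window=3, cap=50.0, query_scale=8 ** -0.5)
```

Soft-capping is applied before the mask here so that masked positions still receive a very large negative logit rather than one clipped to `-cap`.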

Google org

Thank you both for the answers. There are a couple of other changes, such as GQA! Regardless, it's still being worked on and should be out soonish. My apologies for the delay.

JORA looks interesting! I'd suggest adding a link to the paper in the readme.
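GQA (grouped-query attention), mentioned above, lets several query heads share one key/value head. A minimal JAX sketch, assuming a single unbatched sequence and a plain causal mask; the function name and shapes are illustrative, not the official `gemma` API:

```python
import jax
import jax.numpy as jnp


def grouped_query_attention(q, k, v):
    # Hypothetical helper.
    # q: [seq, num_heads, head_dim]; k, v: [seq, num_kv_heads, head_dim].
    num_heads, num_kv_heads = q.shape[1], k.shape[1]
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group = num_heads // num_kv_heads
    # Repeat each KV head `group` times: query head h uses KV head h // group.
    k = jnp.repeat(k, group, axis=1)
    v = jnp.repeat(v, group, axis=1)
    logits = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    seq = q.shape[0]
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    logits = jnp.where(causal[None], logits, -1e30)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("hqk,khd->qhd", weights, v)


# Illustrative usage: 8 query heads sharing 2 KV heads (arbitrary sizes).
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (4, 8, 16))
k = jax.random.normal(key_k, (4, 2, 16))
v = jax.random.normal(key_v, (4, 2, 16))
out = grouped_query_attention(q, k, v)
```

Repeating K/V is the simplest formulation. The memory savings come from storing only `num_kv_heads` heads (e.g. in a KV cache); the repeat itself materializes full-size tensors at compute time, which reshaping the query heads into `[num_kv_heads, group]` would avoid.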

Google org

We haven't forgotten about this. We're making some final changes, and it's on its way to release.

Google org

I'd also suggest sending a PR to add it to https://github.com/n2cholas/awesome-jax

It's updated! Check it out, folks. Hope you enjoy the models!

@canyon289 Hi, could you point out where the JAX/Flax implementation of the model is? I couldn't find any Python code related to the Gemma 2 implementation, only weight files on Kaggle.

The official JAX repo has the configurations for Gemma 2: https://github.com/google-deepmind/gemma/blob/main/gemma/transformer.py
