



Posted by: Martin Görner – Keras Product Manager

The Keras team is pleased to announce that Gemma, a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models, is now available in the KerasNLP collection. Thanks to Keras 3, Gemma runs on JAX, PyTorch, and TensorFlow. With this release, Keras also introduces several new features designed specifically for large language models: a new LoRA API (Low Rank Adaptation) and model-parallel training capabilities for large models.

If you would like to see the code samples directly, please go here.

Let's get started

Gemma models come in portable 2B and 7B parameter sizes and deliver a significant improvement over similar open models, and even over some larger ones. For example:

Gemma 7B sets a new best-in-class score of 64.3% correct answers on the MMLU language understanding benchmark (vs. 62.5% for Mistral-7B and 54.8% for Llama2-13B). Gemma adds +11 percentage points on the GSM8K benchmark of grade-school math problems (46.4% for Gemma 7B vs. 35.4% for Mistral-7B and 28.7% for Llama2-13B), and +6.1 percentage points of correct answers on HumanEval, a coding challenge (32.3% for Gemma 7B vs. 26.2% for Mistral-7B and 18.3% for Llama2-13B).

Gemma models are provided with a familiar KerasNLP API and a very readable Keras implementation. You can instantiate a model with one line of code.

import keras_nlp
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

Then run it directly on a text prompt. Yes, tokenization is built in, although you can easily split it out if you need to. Read our Keras NLP guide to learn how.

gemma_lm.generate("Keras", max_length=32)
> "Keras is a popular deep learning framework for neural networks…"

Try it here: Try using the Gemma model

Fine-tuning the Gemma model using LoRA

Thanks to Keras 3, you can choose which backend you want to run your model on. Here's how to switch:

import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"
import keras  # import keras after selecting the backend

Keras 3 comes with several new features specifically for large language models. The main one is the new LoRA API (Low Rank Adaptation) for parameter-efficient fine-tuning. Here's how to enable it:

gemma_lm.backbone.enable_lora(rank=4)
# Note: rank=4 replaces the weight matrix of the relevant layers
# with the product AxB of two matrices of rank 4,
# which reduces the number of trainable parameters.

This one line reduces the number of trainable parameters from 2.5 billion to 1.3 million.
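To make the source of that reduction concrete, here is a minimal sketch of the parameter arithmetic behind a low-rank factorization. The layer shape is an illustrative assumption rather than Gemma's actual dimensions, and `lora_param_count` is a hypothetical helper, not part of the Keras API.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters when a (d_in x d_out) update is factored as A @ B,
    with A of shape (d_in, rank) and B of shape (rank, d_out)."""
    return rank * (d_in + d_out)


# Illustrative layer shape (an assumption, not taken from Gemma).
d_in, d_out, rank = 2048, 2048, 4

full = d_in * d_out                         # parameters in the dense matrix
lora = lora_param_count(d_in, d_out, rank)  # parameters in the two factors

print(f"full matrix:    {full:,} parameters")
print(f"rank-{rank} factors: {lora:,} parameters")
print(f"reduction:      {full // lora}x")
```

The saving grows with the layer size because the dense matrix scales with the product of its dimensions while the two factors scale with their sum.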

Try it here: Fine-tune your Gemma model using LoRA.

Fine-tune Gemma models on multiple GPUs/TPUs

Keras 3 also supports large-scale model training, and Gemma is a great model to try it out with. The new Keras distribution API offers data-parallel and model-parallel distributed training options. The new API is meant to be multi-backend, but for the time being it is implemented for the JAX backend only, because of its proven scalability (Gemma models were trained with JAX).

A distributed setup is useful for fine-tuning the larger Gemma 7B: for example, a TPUv3 with 8 TPU cores, which you can get for free on Kaggle, or an 8-GPU machine from Google Cloud. Here's how to configure the model for distributed training using model parallelism:

device_mesh = keras.distribution.DeviceMesh(
    (1, 8),  # Mesh topology
    ["batch", "model"],  # Named mesh axes
    devices=keras.distribution.list_devices(),  # Actual accelerators
)

# Model config
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (None, "model")
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (None, "model", None)
layout_map["decoder_block.*attention_output.*kernel"] = (None, None, "model")
layout_map["decoder_block.*ffw_gating.*kernel"] = ("model", None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, "model")

# Set the model config and load the model
model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
# Ready: you can now train with model.fit() or generate text with generate()

This code snippet arranges the 8 accelerators into a 1 x 8 matrix whose two dimensions are called “batch” and “model”. Model weights are sharded along the “model” dimension, here split across the 8 accelerators, while data batches are not partitioned because the “batch” dimension is 1.
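As a rough illustration of what such a layout means for each accelerator, the sketch below computes the per-device shard shape of a weight from its layout tuple and the mesh sizes. `shard_shape` is a hypothetical helper, the weight shapes are illustrative assumptions, and the sketch simplifies by requiring every sharded dimension to divide evenly; it is not how Keras or XLA implement sharding.

```python
def shard_shape(weight_shape, layout, mesh_shape):
    """Per-device shape of a weight sharded according to `layout`.

    `layout` has one entry per weight dimension: a mesh axis name to split
    that dimension across the axis, or None to keep it unsplit.
    """
    shape = []
    for dim, axis in zip(weight_shape, layout):
        parts = 1 if axis is None else mesh_shape[axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        shape.append(dim // parts)
    return tuple(shape)


mesh_shape = {"batch": 1, "model": 8}  # the 1 x 8 mesh described above

# Illustrative embedding shape (an assumption, not read from the model):
# the second dimension is split 8 ways along the "model" axis.
embedding_shard = shard_shape((256000, 3072), (None, "model"), mesh_shape)
print(embedding_shard)

# A data batch split along the "batch" axis of size 1 stays whole.
batch_shard = shard_shape((16, 512), ("batch", None), mesh_shape)
print(batch_shard)
```

With a "batch" axis larger than 1, the same helper would show the data batch being divided as well, which is the data-parallel case.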

Try it here: Fine-tune your Gemma model on multiple GPUs/TPUs.

What's next

We will soon publish a guide showing how to correctly partition a Transformer model and how to write the 6 lines of partitioning setup above. It is not very long, but it would not fit in this post.

You will have noticed that the layer partitionings are defined through regular expressions on layer names. You can check the layer names with this code snippet. We ran it to construct the LayoutMap above.

# This is for the first Transformer block only,
# but they all have the same structure
tlayer = gemma_lm.backbone.get_layer('decoder_block_0')
for variable in tlayer.weights:
    print(f'{variable.path:<58}  {str(variable.shape):<16}')
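To see how regular-expression keys select variables, here is a small standalone sketch. It assumes the keys are matched against variable paths the way Python's `re.search` does; the variable paths listed are hypothetical examples, and `find_layout` is an illustrative helper, not the Keras lookup itself.

```python
import re

# The same patterns as in the LayoutMap above, kept in a plain dict.
layout_rules = {
    "token_embedding/embeddings": (None, "model"),
    "decoder_block.*attention.*(query|key|value).*kernel": (None, "model", None),
    "decoder_block.*attention_output.*kernel": (None, None, "model"),
    "decoder_block.*ffw_gating.*kernel": ("model", None),
    "decoder_block.*ffw_linear.*kernel": (None, "model"),
}


def find_layout(path):
    """Return the layout of the first rule whose regex is found in `path`."""
    for pattern, layout in layout_rules.items():
        if re.search(pattern, path):
            return layout
    return None  # no rule matched: this sketch leaves the variable unsharded


# Hypothetical variable paths, for illustration only.
paths = [
    "decoder_block_0/attention/query/kernel",
    "decoder_block_0/attention/attention_output/kernel",
    "decoder_block_0/ffw_gating/kernel",
    "decoder_block_0/pre_attention_norm/scale",
]
for path in paths:
    print(f"{path:<58} {find_layout(path)}")
```

Because `.*` matches any characters, one pattern covers the same variable in every decoder block, which is why a handful of rules is enough for the whole model.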

Full GSPMD model parallelism works here with just a few partitioning hints, because Keras passes these settings to the powerful XLA compiler, which figures out all the other details of the distributed computation.

We hope you will enjoy playing with Gemma models. There is also an instruction-tuning tutorial here that you may find helpful. And by the way, if you want to share your fine-tuned weights with the community, the Kaggle Model Hub now supports uploading user-tuned weights. Visit the model page for Gemma models on Kaggle to see what others have already created.

