Gemma
This model was released on 2024-03-13 and added to Hugging Face Transformers on 2024-02-21.
Gemma is a family of lightweight language models with pretrained and instruction-tuned variants, available in 2B and 7B parameter sizes. The architecture is a decoder-only transformer that uses Multi-Query Attention, rotary positional embeddings (RoPE), GeGLU activation functions, and RMSNorm normalization.
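To make two of the components named above more concrete, here is a minimal pure-Python sketch of RMSNorm and a GeGLU-style gate. It is an illustration under stated assumptions, not the library's implementation: the function names, the `eps` value, the tanh approximation of GELU, and the plain-list inputs are all choices made for this example.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Divide x by its root mean square, then apply a per-element gain.

    `eps` and the plain multiplicative `weight` are assumptions of this sketch.
    """
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def gelu_tanh(v):
    """Tanh approximation of the GELU activation."""
    inner = math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)
    return 0.5 * v * (1.0 + math.tanh(inner))


def geglu(gate, up):
    """GeGLU gating: GELU(gate) multiplied element-wise with `up`.

    In a real feed-forward block, `gate` and `up` would be two separate linear
    projections of the same hidden state; here they are passed in directly.
    """
    return [gelu_tanh(g) * u for g, u in zip(gate, up)]


hidden = [3.0, 4.0]
normed = rms_norm(hidden, weight=[1.0, 1.0])
gated = geglu(gate=[0.0, 1.0], up=[2.0, 2.0])
print(normed)
print(gated)
```

After normalization the mean of the squared elements is approximately 1, which is the property RMSNorm is designed to provide without subtracting the mean as LayerNorm does.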
The instruction-tuned variant was fine-tuned with supervised learning on instruction-following data, followed by reinforcement learning from human feedback (RLHF) to align the model outputs with human preferences.
You can find all the original Gemma checkpoints under the Gemma release.
The example below demonstrates how to generate text with Pipeline or the AutoModel class, and from the command line.
Pipeline:

    import torch
    from transformers import pipeline

    pipeline = pipeline(
        task="text-generation",
        model="google/gemma-2b",
        dtype=torch.bfloat16,
        device_map="auto",
    )

    pipeline("LLMs generate text through a process known as", max_new_tokens=50)

AutoModel:

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b",
        dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="sdpa"
    )

    input_text = "LLMs generate text through a process known as"
    input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)

    outputs = model.generate(**input_ids, max_new_tokens=50, cache_implementation="static")
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Command line:

    echo -e "LLMs generate text through a process known as" | transformers run --task text-generation --model google/gemma-2b --device 0

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.
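As a rough back-of-the-envelope illustration of why lower precision helps, the sketch below estimates weight storage for a few bit widths. The parameter count is a nominal round figure assumed for this example, not a measured size of any checkpoint, and the estimate ignores activations, the KV cache, and any per-block quantization metadata.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Approximate memory needed to store the weights alone, in GiB."""
    total_bytes = num_params * bits_per_param / 8
    return total_bytes / 1024 ** 3


# Assumption: a nominal 7 billion parameters, used only for illustration.
nominal_params = 7e9

for label, bits in [("bfloat16", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{label:>8}: {weight_memory_gib(nominal_params, bits):.2f} GiB")
```

Halving the bits per weight halves the estimate, so moving from 16-bit to 4-bit storage cuts the weight footprint to roughly a quarter.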
The example below uses bitsandbytes to quantize only the weights to 4 bits (NF4).
    #!pip install bitsandbytes
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4"
    )
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-7b",
        quantization_config=quantization_config,
        device_map="auto",
        attn_implementation="sdpa"
    )

    input_text = "LLMs generate text through a process known as."
    input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **input_ids,
        max_new_tokens=50,
        cache_implementation="static"
    )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Use the AttentionMaskVisualizer to better understand what tokens the model can and cannot attend to.
    from transformers.utils.attention_visualizer import AttentionMaskVisualizer

    visualizer = AttentionMaskVisualizer("google/gemma-2b")
    visualizer("LLMs generate text through a process known as")
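For intuition about what such a visualization conveys for a decoder-only model, the following standalone sketch builds a causal (lower-triangular) mask by hand. It does not use or reproduce the visualizer's output: the helper names are hypothetical, and the whitespace split stands in for the real tokenizer.

```python
def causal_mask(num_tokens):
    """Row i marks which positions token i may attend to: itself and earlier tokens."""
    return [[key <= query for key in range(num_tokens)] for query in range(num_tokens)]


def render(tokens):
    """Draw the mask as text: 'x' means attention is allowed, '.' means it is blocked."""
    lines = []
    for token, row in zip(tokens, causal_mask(len(tokens))):
        cells = " ".join("x" if allowed else "." for allowed in row)
        lines.append(f"{token:>10} | {cells}")
    return "\n".join(lines)


# Assumption: whitespace splitting, not the model's actual tokenization.
tokens = "LLMs generate text through".split()
print(render(tokens))
```

Each row gains one more allowed position than the row above it, which is why a token can never attend to tokens that come after it.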
- The original Gemma models support the standard KV caching used in many transformer-based language models. You can use the default DynamicCache instance or a tuple of tensors for past key values during generation, which makes the model compatible with typical autoregressive generation workflows.

      import torch
      from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

      tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
      model = AutoModelForCausalLM.from_pretrained(
          "google/gemma-2b",
          dtype=torch.bfloat16,
          device_map="auto",
          attn_implementation="sdpa"
      )

      input_text = "LLMs generate text through a process known as"
      input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
      past_key_values = DynamicCache(config=model.config)
      outputs = model.generate(**input_ids, max_new_tokens=50, past_key_values=past_key_values)
      print(tokenizer.decode(outputs[0], skip_special_tokens=True))
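The idea behind KV caching can be sketched with a toy single-head attention in pure Python. This is a conceptual illustration only: `ToyKVCache` and `attend` are hypothetical names, the token vector is reused as query, key, and value (identity projections), and none of this mirrors the internals of DynamicCache.

```python
import math


def attend(query, keys, values):
    """Scaled dot-product attention of one query over all stored keys and values."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]


class ToyKVCache:
    """Hypothetical append-only cache holding the keys and values of past tokens."""

    def __init__(self):
        self.keys = []
        self.values = []

    def update(self, key, value):
        self.keys.append(key)
        self.values.append(value)
        return self.keys, self.values


# Assumption: each token vector serves as its own query, key, and value.
steps = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# With a cache: each step adds only the new token's key and value.
cache = ToyKVCache()
cached_outputs = []
for token in steps:
    keys, values = cache.update(token, token)
    cached_outputs.append(attend(token, keys, values))

# Without a cache: each step rebuilds the full prefix from scratch.
recomputed_outputs = [
    attend(steps[t], steps[: t + 1], steps[: t + 1]) for t in range(len(steps))
]

print(cached_outputs == recomputed_outputs)
```

Both paths produce the same attention outputs; the cached path simply avoids recomputing keys and values for tokens that have already been processed, which is what makes step-by-step generation cheaper.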
GemmaConfig
[[autodoc]] GemmaConfig
GemmaTokenizer
[[autodoc]] GemmaTokenizer
GemmaTokenizerFast
[[autodoc]] GemmaTokenizerFast
GemmaModel
[[autodoc]] GemmaModel
    - forward
GemmaForCausalLM
[[autodoc]] GemmaForCausalLM
    - forward
GemmaForSequenceClassification
[[autodoc]] GemmaForSequenceClassification
    - forward
GemmaForTokenClassification
[[autodoc]] GemmaForTokenClassification
    - forward