Wyświetl na ai.google.dev | Uruchom w Google Colab | Wyświetl źródło w GitHubie |
Prezentujemy CodeGemma, zbiór modeli otwartego kodu opartych na modelach Gemma firmy Google DeepMind (Gemma Team i in., 2024 r). CodeGemma to rodzina lekkich, nowoczesnych modeli otwartych opartych na tych samych badaniach i technologii, które posłużyły do utworzenia modeli Gemini.
Bazując na wstępnie wytrenowanych modelach Gemma, modele CodeGemma są trenowane na podstawie ponad 500–1000 miliardów tokenów głównie kodu przy użyciu o tej samej architekturze co rodzina modeli Gemma. Dzięki temu modele CodeGemma osiągają najwyższą wydajność kodu zarówno podczas i generowania zadań, a jednocześnie rozumienia i rozumowania na dużą skalę.
CodeGemma ma 3 warianty:
- Wytrenowany model 7B z użyciem kodu
- Model kodu 7B dostrojony według instrukcji
- Model 2B wytrenowany specjalnie pod kątem uzupełniania kodu i generowania otwartego.
Z tego przewodnika dowiesz się, jak używać modelu CodeGemma z narzędziem Flax do uzupełniania kodu.
Konfiguracja
1. Konfigurowanie dostępu do Kaggle dla CodeGemma
Aby ukończyć ten samouczek, musisz najpierw wykonać instrukcje konfiguracji opisane w artykule Konfiguracja Gemma, z którego dowiesz się, jak:
- Dostęp do CodeGemma uzyskasz na stronie kaggle.com.
- Aby uruchomić model CodeGemma, wybierz środowisko wykonawcze Colab z wystarczającą ilością zasobów (GPU T4 ma niewystarczającą ilość pamięci – użyj TPU w wersji 2).
- Wygeneruj i skonfiguruj nazwę użytkownika i klucz interfejsu API Kaggle.
Po zakończeniu konfiguracji Gemma przejdź do następnej sekcji, w której możesz ustawić zmienne środowiskowe dla środowiska Colab.
2. Ustawianie zmiennych środowiskowych
Ustaw zmienne środowiskowe dla interfejsów KAGGLE_USERNAME
i KAGGLE_KEY
. Kiedy pojawi się komunikat „Przyznać dostęp?”, Użytkownik wyraża zgodę na przyznanie tajnego dostępu.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. Zainstaluj bibliotekę gemma
Bezpłatna akceleracja sprzętowa Colab jest obecnie niewystarczająca do uruchomienia tego notatnika. Jeśli korzystasz z Colab Pay As You Go lub Colab Pro, kliknij Edytuj > Ustawienia notatnika > Wybierz GPU A100 > Zapisz, aby włączyć akcelerację sprzętową.
Następnie musisz zainstalować bibliotekę Google DeepMind gemma
ze strony github.com/google-deepmind/gemma
. Jeśli pojawi się błąd dotyczący resolvera zależności pip, zwykle możesz go zignorować.
pip install -q git+https://github.com/google-deepmind/gemma.git
4. Importuj biblioteki
Ten notatnik korzysta z usługi Gemma (która używa Flax do tworzenia warstw sieci neuronowych) i SentencePiece (do tokenizacji).
import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
Wczytaj model CodeGemma
Wczytaj model CodeGemma za pomocą parametru kagglehub.model_download
, który przyjmuje 3 argumenty:
handle
: uchwyt modelu z Kagglepath
: (opcjonalny ciąg znaków) ścieżka lokalnaforce_download
: (opcjonalna wartość logiczna) wymusza ponowne pobranie modelu.
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
Sprawdź lokalizację wag modelu i tokenizatora, a następnie ustaw zmienne ścieżki. Katalog tokenizera znajduje się w katalogu głównym, z którego został pobrany model, a wagi modelu – w podkatalogu. Na przykład:
- Plik tokenizera
spm.model
znajdzie się w lokalizacji/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
- Punkt kontrolny modelu będzie w:
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
Przeprowadź próbkowanie/wnioskowanie
Wczytaj i sformatuj punkt kontrolny modelu CodeGemma za pomocą metody gemma.params.load_and_format_params
:
params = params_lib.load_and_format_params(CKPT_PATH)
Wczytaj tokenizer CodeGemma utworzony za pomocą sentencepiece.SentencePieceProcessor
:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
Aby automatycznie wczytywać prawidłową konfigurację z punktu kontrolnego modelu CodeGemma, użyj narzędzia gemma.transformer.TransformerConfig
. Argument cache_size
to liczba kroków w pamięci podręcznej CodeGemma Transformer
. Następnie utwórz instancję modelu CodeGemma jako model_2b
za pomocą parametru gemma.transformer.Transformer
(dziedziczącego z flax.linen.Module
).
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
Utwórz sampler
w gemma.sampler.Sampler
. Wykorzystuje punkt kontrolny modelu CodeGemma i tokenizer.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
Utwórz zmienne reprezentujące tokeny typu Fill-in-the-middle (fim) oraz funkcje pomocnicze, aby sformatować prompt i wygenerowane dane wyjściowe.
Spójrzmy na przykład na ten kod:
def function(string):
assert function('asdf') == 'fdsa'
Chcielibyśmy wypełnić pole function
, tak aby potwierdzenie zawierało True
. W tym przypadku prefiks będzie wyglądał tak:
"def function(string):\n"
Sufiks będzie wyglądał tak:
"assert function('asdf') == 'fdsa'"
Następnie formatujemy go w postaci PREFIX-SUFFIX-MIDDLE (środkowa sekcja, którą należy wypełnić, znajduje się zawsze na końcu wiersza):
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
Utwórz prompt i przeprowadź wnioskowanie. Określ tekst prefiksu before
i tekst sufiksu after
i wygeneruj sformatowany prompt za pomocą funkcji pomocniczej format_completion prompt
.
Możesz dostosować total_generation_steps
(liczbę kroków wykonanych podczas generowania odpowiedzi – w tym przykładzie użyto 100
do zachowania pamięci hosta).
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
Więcej informacji
- Więcej informacji o bibliotece Google DeepMind
gemma
znajdziesz na GitHubie, która zawiera ciągi dokumentów z modułami użytymi w tym samouczku, takie jakgemma.params
,gemma.transformer
orazgemma.sampler
. - Te biblioteki mają własne witryny z dokumentacją: core JAX, Flax i Orbax.
- Dokumentację usługi tokenizacji i detokenizera usługi
sentencepiece
znajdziesz w repozytorium Google na GitHubiesentencepiece
. - Dokumentację usługi
kagglehub
znajdziesz w witrynieREADME.md
w repozytorium GitHubkagglehub
firmy Kaggle. - Dowiedz się, jak używać modeli Gemma w Vertex AI Google Cloud.
- Jeśli używasz jednostek Google Cloud TPU (wersja 3-8 lub nowsza), zaktualizuj też pakiet
jax[tpu]
do najnowszej wersji (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
), uruchom ponownie środowisko wykonawcze i sprawdź, czy wersjejax
ijaxlib
są zgodne (!pip list | grep jax
). Może to zapobiec powstawaniu błędów typuRuntimeError
z powodu niezgodności wersjijaxlib
ijax
. Więcej instrukcji instalacji języka JAX znajdziesz w dokumentacji JAX.