Programmazione assistita dall'IA con CodeGemma e KerasNLP

Visualizza su ai.google.dev Esegui in Google Colab Visualizza il codice sorgente su GitHub

Panoramica

CodeGemma è una variante di Gemma ottimizzata per le attività di programmazione. Questo tutorial si basa sulla guida rapida di Keras CodeGemma e mostra altri modi in cui CodeGemma può supportare le tue attività di programmazione.

Configurazione

Accedi a CodeGemma

Per completare questo tutorial, devi prima completare le istruzioni di configurazione nella configurazione di Gemma. Le istruzioni di configurazione di Gemma mostrano come fare:

  • Accedi a Gemma su kaggle.com.
  • Seleziona un runtime Colab con risorse sufficienti per eseguire il modello Gemma 7B.
  • Genera e configura un nome utente e una chiave API Kaggle.

Dopo aver completato la configurazione di Gemma, passa alla sezione successiva, in cui imposterai le variabili di ambiente per il tuo ambiente Colab.

Seleziona il runtime

Per eseguire i modelli CodeGemma 7B, devi avere un piano Colab Pro a pagamento che fornisce un runtime con una GPU A100.

  1. Nell'angolo in alto a destra della finestra Colab, seleziona ▾ (Opzioni di connessione aggiuntive).
  2. Seleziona Modifica tipo di runtime.
  3. In Acceleratore hardware, seleziona GPU A100.

Configura la chiave API

Per utilizzare Gemma, devi fornire il tuo nome utente Kaggle e una chiave API Kaggle.

Per generare una chiave API Kaggle, vai alla scheda Account del tuo profilo utente Kaggle e seleziona Create New Token (Crea nuovo token). Verrà attivato il download di un file kaggle.json contenente le tue credenziali API.

In Colab, seleziona Secret (webinar) nel riquadro a sinistra e aggiungi il tuo nome utente Kaggle e la tua chiave API Kaggle. Memorizza il tuo nome utente sotto il nome KAGGLE_USERNAME e la chiave API sotto il nome KAGGLE_KEY.

Imposta le variabili di ambiente

Imposta le variabili di ambiente per KAGGLE_USERNAME e KAGGLE_KEY.

import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installa le dipendenze

pip install -q -U keras-nlp

Seleziona un servizio di backend

Keras è un'API di deep learning multi-framework di alto livello progettata per la semplicità e la facilità d'uso. Con Keras 3, puoi eseguire flussi di lavoro su uno dei tre backend: TensorFlow, JAX o PyTorch.

Per questo tutorial, configura il backend per JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".

Importa pacchetti

Importare Keras e KerasNLP.

import keras_nlp
import keras

# Run at half precision.
keras.config.set_floatx("bfloat16")

Esempi di modelli CodeGemma 7B

Questa sezione tratta esempi di utilizzo del modello preaddestrato 7B CodeGemma per agevolare le attività di programmazione.

Carica il modello

KerasNLP fornisce implementazioni di tutte e tre le varianti di CodeGemma (2B e 7B preaddestrate (PT) e 7B ottimizzate per l'istruzione (IT)) utilizzando GemmaCausalLM, un modello Gemma end-to-end per la modellazione del linguaggio causale. Un modello linguistico causale prevede il token successivo in base a quelli precedenti.

Per questo esempio, carica il modello code_gemma_7b_en utilizzando il metodo from_preset.

gemma_lm_7b = keras_nlp.models.GemmaCausalLM.from_preset("code_gemma_7b_en")
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/config.json...
100%|██████████| 556/556 [00:00<00:00, 790kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/model.weights.h5...
100%|██████████| 15.9G/15.9G [02:39<00:00, 107MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/tokenizer.json...
100%|██████████| 401/401 [00:00<00:00, 587kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.04M/4.04M [00:00<00:00, 16.4MB/s]
gemma_lm_7b.summary()

Il metodo from_preset crea un'istanza del modello da un'architettura e pesi predefiniti.

Completamento del codice con FIM multilinea

I modelli PT CodeGemma sono addestrati su attività di compilazione del codice. Questa sezione mostra esempi che utilizzano la funzionalità di compilazione automatica multilinea (FIM) di CodeGemma per compilare automaticamente il codice nella posizione del cursore specificata in base al contesto circostante.

Come primo passaggio, definisci le costanti e una funzione helper per la formattazione dei prompt.

# Formatting control tokens to specify cursor location
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

# Define model stop tokens
END_TOKEN = gemma_lm_7b.preprocessor.tokenizer.end_token
stop_tokens = (BEFORE_CURSOR, AFTER_CURSOR, AT_CURSOR, FILE_SEPARATOR, END_TOKEN)
stop_token_ids = tuple(gemma_lm_7b.preprocessor.tokenizer.token_to_id(x) for x in stop_tokens)

def format_completion_prompt(before, after):
    return f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"

Esempio 1: inserire una condizione mancante

Il codice di esempio riportato di seguito per generare la sequenza di Fibonacci non verrà eseguito correttamente se n=1:

def fibonacci(n: int) -> int:
  if n == 0:
    return 0
  # The cursor is right before the e in the following line
  else:
    return fibonacci(n - 1) + fibonacci(n - 2)

Supponendo che il cursore si trovi all'inizio della riga 4 (dove si trova la clausola else), i contenuti prima e dopo il cursore sono:

before = """def fibonacci(n: int) -> int:\n  if n == 0:\n    return 0\n""" # Mind the spaces!
after = """\n  else:\n    return fibonacci(n - 1) + fibonacci(n-2)\n"""
prompt = format_completion_prompt(before, after)
print(prompt)
<|fim_prefix|>def fibonacci(n: int) -> int:
  if n == 0:
    return 0
<|fim_suffix|>
  else:
    return fibonacci(n - 1) + fibonacci(n-2)
<|fim_middle|>

Esegui il prompt.

print(gemma_lm_7b.generate(prompt, stop_token_ids=stop_token_ids, max_length=128))
<|fim_prefix|>def fibonacci(n: int) -> int:
  if n == 0:
    return 0
<|fim_suffix|>
  else:
    return fibonacci(n - 1) + fibonacci(n-2)
<|fim_middle|>elif n == 1:
    return 1<|file_separator|>

Il modello inserisce la condizione elif corretta per n=1 nella posizione del cursore.

Esempio 2: algoritmo di attraversamento DFS completo

Codice di completamento automatico per un algoritmo di attraversamento ad albero della ricerca DFS (depth-first).

before = """void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }"""
after = """\nprintf("%d", root->value);
}"""
prompt = format_completion_prompt(before, after)
print(prompt)
<|fim_prefix|>void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }<|fim_suffix|>
printf("%d", root->value);
}<|fim_middle|>

Esegui il prompt.

print(gemma_lm_7b.generate(prompt, stop_token_ids=stop_token_ids, max_length=128))
<|fim_prefix|>void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }<|fim_suffix|>
printf("%d", root->value);
}<|fim_middle|>
  if (root->right) {
    dfs(root->right);
  }<|file_separator|>

Generazione del codice

Oltre al code infill, il modello CodeGemma 7B PT è anche addestrato su corpus di linguaggio naturale. Puoi utilizzarlo per richiedere al modello di generare il codice.

generation_prompt= """Write a rust function to identify non-prime numbers.
Examples:
>>> is_not_prime(2)
False
>>> is_not_prime(10)
True
pub fn is_not_prime(n: i32) -> bool {"""
print(gemma_lm_7b.generate(generation_prompt, max_length=500))
Write a rust function to identify non-prime numbers.
Examples:
>>> is_not_prime(2)
False
>>> is_not_prime(10)
True
pub fn is_not_prime(n: i32) -> bool {
    if n <= 1 {
        return true;
    }
    for i in 2..n {
        if n % i == 0 {
            return true;
        }
    }
    false
}

Esempi di modelli IT di 7 miliardi

Questa sezione utilizza il modello ottimizzato in base alle istruzioni di CodeGemma 7B per attività di programmazione più avanzate. Il modello IT di CodeGemma 7B è derivato dal modello PT CodeGemma 7B attraverso un'ottimizzazione supervisionata del codice insieme all'apprendimento per rinforzo con feedback umano. Questa sezione illustra esempi di utilizzo di questo modello per la generazione aperta.

Carica il modello IT

Carica il modello code_gemma_instruct_7b_en usando il metodo from_preset.

gemma_lm_7b_it = keras_nlp.models.GemmaCausalLM.from_preset("code_gemma_instruct_7b_en")
gemma_lm_7b_it.summary()
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/config.json...
100%|██████████| 556/556 [00:00<00:00, 754kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/model.weights.h5...
100%|██████████| 15.9G/15.9G [03:18<00:00, 86.2MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/tokenizer.json...
100%|██████████| 401/401 [00:00<00:00, 593kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.04M/4.04M [00:00<00:00, 16.8MB/s]

I modelli IT sono addestrati con un formattatore specifico che annota tutti gli esempi di ottimizzazione delle istruzioni con informazioni aggiuntive per indicare i ruoli e delineare i turni di una conversazione.

Come primo passaggio, definisci le costanti e una funzione helper per la formattazione dei prompt.

# Formatting control tokens for instruction tuning
START_OF_TURN_USER = "<start_of_turn>user"
END_OF_TURN = "<end_of_turn>"
START_OF_TURN_MODEL = "<start_of_turn>model"

# Formatting helper function
def format_instruction_prompt(context):
    return f"{START_OF_TURN_USER}\n{context}{END_OF_TURN}\n{START_OF_TURN_MODEL}\n"

Traduzione del codice

context1 = """
You are an experienced C and Python programmer. Convert the following Python code into C.
```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```\n"""

Formatta il prompt.

prompt1 = format_instruction_prompt(context1)
print(prompt1)
<start_of_turn>user

You are an experienced C and Python programmer. Convert the following Python code into C.

```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```
<end_of_turn>
<start_of_turn>model

Esegui il prompt.

print(gemma_lm_7b_it.generate(prompt1, max_length=500))
<start_of_turn>user

You are an experienced C and Python programmer. Convert the following Python code into C.

```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```
<end_of_turn>
<start_of_turn>model
Here is the C code equivalent of the Python code:

```c
int factorial(int n) {
  int result = 1;
  for (int i = 2; i <= n; i++) {
    result *= i;
  }
  return result;
}
```

Here is a breakdown of the changes:

* The function is declared with the `int` return type, as in Python.
* The `for` loop is converted to a `for` loop with an `int` variable `i` initialized to 2 and incremented by 1 in each iteration.
* The `range` function is replaced with a simple loop that iterates from 2 to `n` (inclusive).
* The `result *= i` statement is used to multiply `result` by `i` in each iteration.
* The `return` statement is used to return the final value of `result`.

Rilevamento delle vulnerabilità del codice

context2 = """
You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.
```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p\n", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```\n"""

Formatta il prompt.

prompt2 = format_instruction_prompt(context2)
print(prompt2)
<start_of_turn>user

You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.

```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p
", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```
<end_of_turn>
<start_of_turn>model
print(gemma_lm_7b_it.generate(prompt2, max_length=1000))
<start_of_turn>user

You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.

```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p
", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```
<end_of_turn>
<start_of_turn>model
Yes, the code is vulnerable to a memory access error.

**Reasoning:**

* The code allocates memory for `WidgetList` using `malloc` based on the value of `numWidgets`.
* However, the loop iterates from `0` to `numWidgets`, which is one element beyond the allocated memory.
* This means that accessing `WidgetList[numWidgets]` will result in a memory access error, as it is outside the bounds of the allocated memory.

**Example of Memory Access Error:**

When `numWidgets` is 5, the code allocates memory for `WidgetList` as follows:

```
WidgetList = (Widget **) malloc(5 * sizeof(Widget *));
```

The loop iterates from 0 to 4, accessing the following elements:

* `WidgetList[0]`
* `WidgetList[1]`
* `WidgetList[2]`
* `WidgetList[3]`
* `WidgetList[4]`

However, the code then attempts to access `WidgetList[5]`, which is outside the allocated memory range. This will result in a memory access error.

**Solution:**

To resolve this vulnerability, the loop should be modified to iterate from 0 to `numWidgets - 1`:

```cpp
for (i = 0; i < numWidgets - 1; i++) {
    WidgetList[i] = InitializeWidget();
}
```

This ensures that the loop does not access elements beyond the allocated memory range.

Il modello rileva una potenziale vulnerabilità nel codice e apporta modifiche al codice per mitigarla.

Riepilogo

Questo tutorial ti ha illustrato l'utilizzo di CodeGemma per diverse attività di programmazione. Per scoprire di più su CodeGemma: