AI-assisted programming with CodeGemma and KerasNLP


Overview

CodeGemma is a variant of Gemma fine-tuned for coding tasks. This tutorial builds on the Keras CodeGemma quickstart and shows more ways that CodeGemma can assist with programming tasks.

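Setup

Get access to CodeGemma

To complete this tutorial, you first need to complete the setup instructions on the Gemma setup page. The Gemma setup instructions show you how to do the following:

  • Get access to Gemma on kaggle.com.
  • Select a Colab runtime with sufficient resources to run the Gemma 7B model.
  • Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section. There you'll set environment variables for your Colab environment.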

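Select the runtime

To run the CodeGemma 7B model, you need a paid Colab Pro plan, which gives you access to a runtime with an A100 GPU.

  1. Select "▾" (Additional connection options) in the upper-right of the Colab window.
  2. Select "Change runtime type".
  3. Under "Hardware accelerator", select "A100 GPU".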

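Configure your API key

To use Gemma, you must provide your Kaggle username and a Kaggle API key.

To generate a Kaggle API key, go to the "Account" tab of your Kaggle user profile and select "Create New Token". This triggers the download of a kaggle.json file containing your API credentials.

In Colab, select "Secrets" (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name KAGGLE_USERNAME and your API key under the name KAGGLE_KEY.

Set environment variables

Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY.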

import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
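Outside Colab, google.colab.userdata is not available. The sketch below covers that case. It assumes you keep the kaggle.json file from the previous step at its conventional location, ~/.kaggle/kaggle.json, and that the file holds a JSON object with "username" and "key" fields. The helper name load_kaggle_credentials is hypothetical.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path: str = "~/.kaggle/kaggle.json") -> None:
    """Export KAGGLE_USERNAME and KAGGLE_KEY from a downloaded kaggle.json file.

    Assumes the file contains {"username": ..., "key": ...}.
    Variables that are already set are left untouched.
    """
    creds = json.loads(Path(path).expanduser().read_text())
    os.environ.setdefault("KAGGLE_USERNAME", creds["username"])
    os.environ.setdefault("KAGGLE_KEY", creds["key"])
```

Using setdefault means values you already exported, for example from Colab Secrets, take precedence over the file.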

Install dependencies

pip install -q -U keras-nlp

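Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. With Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX. Set KERAS_BACKEND before importing Keras, because the backend is chosen when Keras is imported.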

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".

Import packages

Import Keras and KerasNLP.

import keras_nlp
import keras

# Run at half precision.
keras.config.set_floatx("bfloat16")
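Running at half precision matters mainly for memory. The sketch below estimates how much memory the weights alone take in each format. It assumes the 7B model has about 8.5 billion parameters; summary() reports the exact count. The helper name weight_memory_gib is hypothetical.

```python
# Bytes needed to store one parameter in each floating-point format.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}


def weight_memory_gib(num_params: float, dtype: str) -> float:
    """Approximate memory for the weights alone, in GiB (2**30 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


# Assumption: about 8.5 billion parameters; check summary() for the exact count.
for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_memory_gib(8.5e9, dtype):.1f} GiB")
```

Under that assumption, the weights need about 31.7 GiB in float32 versus 15.8 GiB in bfloat16. Activations come on top of that.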

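CodeGemma 7B model examples

This section shows examples of using the pretrained 7B CodeGemma model to assist with coding tasks.

Load the model

KerasNLP provides implementations of all three CodeGemma variants: 2B and 7B pretrained (PT), and 7B instruction-tuned (IT). They use GemmaCausalLM, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on the previous tokens.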
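The toy sketch below makes "predicts the next token based on the previous tokens" concrete. It shows greedy autoregressive decoding with stop tokens and a length limit. The vocabulary and scores are made up, and this is not how GemmaCausalLM is implemented. It only illustrates why the generate calls below take stop tokens and a maximum length.

```python
from typing import Dict, List, Sequence

# Hypothetical toy "model": scores for the next token given the last token.
# This is a bigram model; a real causal LM conditions on the whole prefix.
TOY_SCORES: Dict[str, Dict[str, float]] = {
    "<bos>": {"def": 0.9, "return": 0.1},
    "def": {"f():": 0.8, "<eos>": 0.2},
    "f():": {"return": 0.7, "<eos>": 0.3},
    "return": {"1": 0.6, "<eos>": 0.4},
    "1": {"<eos>": 1.0},
}


def greedy_generate(prompt: Sequence[str], stop_tokens: Sequence[str],
                    max_length: int) -> List[str]:
    """Append the highest-scoring next token until a stop token or max_length.

    In this sketch max_length counts the prompt tokens too.
    """
    tokens = list(prompt)
    while len(tokens) < max_length:
        scores = TOY_SCORES[tokens[-1]]
        next_token = max(scores, key=scores.get)
        tokens.append(next_token)
        if next_token in stop_tokens:  # the stop token is kept in the output
            break
    return tokens


print(greedy_generate(["<bos>"], stop_tokens=["<eos>"], max_length=10))
```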

In this example, load the code_gemma_7b_en model using the from_preset method.

gemma_lm_7b = keras_nlp.models.GemmaCausalLM.from_preset("code_gemma_7b_en")
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/config.json...
100%|██████████| 556/556 [00:00<00:00, 790kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/model.weights.h5...
100%|██████████| 15.9G/15.9G [02:39<00:00, 107MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/tokenizer.json...
100%|██████████| 401/401 [00:00<00:00, 587kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.04M/4.04M [00:00<00:00, 16.4MB/s]
gemma_lm_7b.summary()

The from_preset method instantiates the model from a preset architecture and weights.

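Code completion with multi-line FIM

The PT CodeGemma model is trained on code infilling tasks. This section shows how to use CodeGemma's multi-line fill-in-the-middle (FIM) capability to complete code at a specified cursor location based on the surrounding context. The prompt places the code before the cursor after <|fim_prefix|> and the code after the cursor after <|fim_suffix|>. It ends with <|fim_middle|>, after which the model generates the missing middle.

First, define the constants and a prompt formatting helper function.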

# Formatting control tokens to specify cursor location
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

# Define model stop tokens
END_TOKEN = gemma_lm_7b.preprocessor.tokenizer.end_token
stop_tokens = (BEFORE_CURSOR, AFTER_CURSOR, AT_CURSOR, FILE_SEPARATOR, END_TOKEN)
stop_token_ids = tuple(gemma_lm_7b.preprocessor.tokenizer.token_to_id(x) for x in stop_tokens)

def format_completion_prompt(before, after):
    return f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
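In an editor, the cursor is usually a character offset into the file rather than a ready-made before/after pair. The sketch below splits a source string at such an offset and builds the FIM prompt. It repeats the constants and format_completion_prompt so it runs on its own. The helper name prompt_at_cursor is hypothetical.

```python
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"


def format_completion_prompt(before, after):
    return f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"


def prompt_at_cursor(source: str, cursor: int) -> str:
    """Split `source` at character offset `cursor` and build a FIM prompt."""
    if not 0 <= cursor <= len(source):
        raise ValueError(f"cursor {cursor} outside 0..{len(source)}")
    return format_completion_prompt(source[:cursor], source[cursor:])


# Cursor at the start of the second line (offset 9, just after "def f():\n").
print(prompt_at_cursor("def f():\n    return 1\n", 9))
```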

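Example 1 - Insert a missing condition

The following code for generating the Fibonacci sequence does not work correctly when n=1. That input falls through to the else branch and calls fibonacci(-1), which never reaches a base case.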

def fibonacci(n: int) -> int:
  if n == 0:
    return 0
  # The cursor is right before the e in the following line
  else:
    return fibonacci(n - 1) + fibonacci(n - 2)

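Assume the cursor sits on its own line between return 0 and the else clause (the comment line in the snippet above). The content before and after the cursor is then: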

before = """def fibonacci(n: int) -> int:\n  if n == 0:\n    return 0\n""" # Mind the spaces!
after = """\n  else:\n    return fibonacci(n - 1) + fibonacci(n-2)\n"""
prompt = format_completion_prompt(before, after)
print(prompt)
<|fim_prefix|>def fibonacci(n: int) -> int:
  if n == 0:
    return 0
<|fim_suffix|>
  else:
    return fibonacci(n - 1) + fibonacci(n-2)
<|fim_middle|>

Run the prompt.

print(gemma_lm_7b.generate(prompt, stop_token_ids=stop_token_ids, max_length=128))
<|fim_prefix|>def fibonacci(n: int) -> int:
  if n == 0:
    return 0
<|fim_suffix|>
  else:
    return fibonacci(n - 1) + fibonacci(n-2)
<|fim_middle|>elif n == 1:
    return 1<|file_separator|>

The model inserts the missing elif branch for n == 1 at the cursor location.
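generate returns the prompt followed by the completion, as the output above shows. The sketch below pulls out just the infill, given the prompt string and the text returned by generate. extract_infill is a hypothetical helper. You can pass the stop_tokens tuple defined earlier, which also includes the tokenizer's end token, in place of the default.

```python
from typing import Sequence

# The four FIM control tokens used as stop tokens earlier in this tutorial.
STOP_TOKENS = ("<|fim_prefix|>", "<|fim_suffix|>", "<|fim_middle|>", "<|file_separator|>")


def extract_infill(generated: str, prompt: str,
                   stop_tokens: Sequence[str] = STOP_TOKENS) -> str:
    """Return only the text the model wrote at the cursor."""
    if not generated.startswith(prompt):
        raise ValueError("generated text does not start with the prompt")
    completion = generated[len(prompt):]
    # Cut at the earliest stop token that appears in the completion, if any.
    hits = [pos for pos in (completion.find(t) for t in stop_tokens) if pos != -1]
    return completion[:min(hits)] if hits else completion
```

Check the indentation before splicing before + infill + after back into the file. In the output above, the elif line starts at column 0, while the if it continues is indented by two spaces. As a result, the spliced code would not parse as-is.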

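Example 2 - Complete a DFS traversal algorithm

Autocomplete the code for a depth-first search (DFS) tree traversal algorithm. Here, the cursor sits between the recursive call on the left child and the printf call.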

before = """void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }"""
after = """\nprintf("%d", root->value);
}"""
prompt = format_completion_prompt(before, after)
print(prompt)
<|fim_prefix|>void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }<|fim_suffix|>
printf("%d", root->value);
}<|fim_middle|>

執行提示。

print(gemma_lm_7b.generate(prompt, stop_token_ids=stop_token_ids, max_length=128))
<|fim_prefix|>void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }<|fim_suffix|>
printf("%d", root->value);
}<|fim_middle|>
  if (root->right) {
    dfs(root->right);
  }<|file_separator|>
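Read as a whole, the completed function works in three steps. It recurses into the left subtree, then the right subtree, and only then prints the node's value. That order is a post-order traversal. Below is a Python analogue for illustration. The Node class is hypothetical and mirrors the left, right, and value fields of the C node struct.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    value: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def dfs(root: Node, out: List[int]) -> None:
    """Post-order DFS: left subtree, right subtree, then the node itself."""
    if root.left is not None:
        dfs(root.left, out)
    if root.right is not None:
        dfs(root.right, out)
    out.append(root.value)  # stands in for printf("%d", root->value)


tree = Node(1, Node(2, Node(4)), Node(3))
visited: List[int] = []
dfs(tree, visited)
print(visited)  # [4, 2, 3, 1]
```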

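Code generation

In addition to code infilling, the CodeGemma 7B PT model is also trained on natural language corpora. You can use this to prompt the model to generate code from a description. Here, the prompt consists of a task statement, input/output examples, and a function signature for the model to complete.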

generation_prompt = """Write a rust function to identify non-prime numbers.
Examples:
>>> is_not_prime(2)
False
>>> is_not_prime(10)
True
pub fn is_not_prime(n: i32) -> bool {"""
print(gemma_lm_7b.generate(generation_prompt, max_length=500))
Write a rust function to identify non-prime numbers.
Examples:
>>> is_not_prime(2)
False
>>> is_not_prime(10)
True
pub fn is_not_prime(n: i32) -> bool {
    if n <= 1 {
        return true;
    }
    for i in 2..n {
        if n % i == 0 {
            return true;
        }
    }
    false
}

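CodeGemma 7B IT model examples

This section uses the CodeGemma 7B instruction-tuned (IT) model for more advanced coding tasks. The CodeGemma 7B IT model is derived from the CodeGemma 7B PT model through supervised fine-tuning and reinforcement learning from human feedback. This section shows examples of using the model for open-ended generation.

Load the IT model

Load the code_gemma_instruct_7b_en model using the from_preset method.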

gemma_lm_7b_it = keras_nlp.models.GemmaCausalLM.from_preset("code_gemma_instruct_7b_en")
gemma_lm_7b_it.summary()
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/config.json...
100%|██████████| 556/556 [00:00<00:00, 754kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/model.weights.h5...
100%|██████████| 15.9G/15.9G [03:18<00:00, 86.2MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/tokenizer.json...
100%|██████████| 401/401 [00:00<00:00, 593kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.04M/4.04M [00:00<00:00, 16.8MB/s]

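The IT model is trained with a specific formatter. The formatter annotates all instruction tuning examples with extra information that indicates roles and delineates turns in a conversation.

First, define the constants and a prompt formatting helper function.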

# Formatting control tokens for instruction tuning
START_OF_TURN_USER = "<start_of_turn>user"
END_OF_TURN = "<end_of_turn>"
START_OF_TURN_MODEL = "<start_of_turn>model"

# Formatting helper function
def format_instruction_prompt(context):
    return f"{START_OF_TURN_USER}\n{context}{END_OF_TURN}\n{START_OF_TURN_MODEL}\n"
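The helper above covers a single user turn. The sketch below extends it to a follow-up conversation. It reuses the same control tokens for alternating user and model turns and ends with an open model turn. It assumes earlier model replies are closed with <end_of_turn> in the same way as user turns. For a single user turn, it produces exactly the same string as format_instruction_prompt. The function name format_conversation is hypothetical.

```python
from typing import Sequence, Tuple

START_OF_TURN = "<start_of_turn>"
END_OF_TURN = "<end_of_turn>"


def format_conversation(turns: Sequence[Tuple[str, str]]) -> str:
    """Render (role, text) turns, then open a model turn for the next reply.

    Roles must be "user" or "model", matching the control tokens above.
    """
    parts = []
    for role, text in turns:
        if role not in ("user", "model"):
            raise ValueError(f"unknown role: {role!r}")
        parts.append(f"{START_OF_TURN}{role}\n{text}{END_OF_TURN}\n")
    parts.append(f"{START_OF_TURN}model\n")
    return "".join(parts)
```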

Code translation

context1 = """
You are an experienced C and Python programmer. Convert the following Python code into C.
```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```\n"""
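Python integers are unbounded, while a C int is usually 32 bits. A literal translation of factorial can therefore return correct values only up to a certain n. The quick check below assumes a 32-bit signed int (INT32_MAX = 2**31 - 1) and a 64-bit signed long long. The helper name largest_n_that_fits is hypothetical.

```python
import math

INT32_MAX = 2**31 - 1
INT64_MAX = 2**63 - 1


def largest_n_that_fits(limit: int) -> int:
    """Largest n such that n! <= limit."""
    n = 0
    while math.factorial(n + 1) <= limit:
        n += 1
    return n


print(largest_n_that_fits(INT32_MAX))  # 12
print(largest_n_that_fits(INT64_MAX))  # 20
```

With a 32-bit int, a C factorial overflows from n = 13 onward, and signed overflow is undefined behavior in C. Consider stating the integer width you want in the prompt.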

Format the prompt.

prompt1 = format_instruction_prompt(context1)
print(prompt1)
<start_of_turn>user

You are an experienced C and Python programmer. Convert the following Python code into C.

```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```
<end_of_turn>
<start_of_turn>model

Run the prompt.

print(gemma_lm_7b_it.generate(prompt1, max_length=500))
<start_of_turn>user

You are an experienced C and Python programmer. Convert the following Python code into C.

```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```
<end_of_turn>
<start_of_turn>model
Here is the C code equivalent of the Python code:

```c
int factorial(int n) {
  int result = 1;
  for (int i = 2; i <= n; i++) {
    result *= i;
  }
  return result;
}
```

Here is a breakdown of the changes:

* The function is declared with the `int` return type, as in Python.
* The `for` loop is converted to a `for` loop with an `int` variable `i` initialized to 2 and incremented by 1 in each iteration.
* The `range` function is replaced with a simple loop that iterates from 2 to `n` (inclusive).
* The `result *= i` statement is used to multiply `result` by `i` in each iteration.
* The `return` statement is used to return the final value of `result`.
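The returned text echoes the whole prompt, including the Python block in the user turn, before the model's answer. The sketch below keeps only the model's reply and then extracts the fenced code blocks from it. model_reply and code_blocks are hypothetical helpers. The regular expression assumes the reply uses three-backtick fences with an optional language tag.

```python
import re
from typing import List, Tuple

START_OF_TURN_MODEL = "<start_of_turn>model"
END_OF_TURN = "<end_of_turn>"
# Three backticks, an optional language tag, a newline, then the body (non-greedy).
CODE_FENCE_RE = re.compile(r"`{3}(\w*)\n(.*?)`{3}", re.DOTALL)


def model_reply(generated: str) -> str:
    """Return the text after the last model-turn marker, up to <end_of_turn>."""
    reply = generated.rsplit(START_OF_TURN_MODEL, 1)[-1]
    return reply.split(END_OF_TURN, 1)[0].strip()


def code_blocks(reply: str) -> List[Tuple[str, str]]:
    """Return (language tag, code) pairs for each fenced block in the reply."""
    return [(lang, body) for lang, body in CODE_FENCE_RE.findall(reply)]
```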

Code vulnerability detection

context2 = """
You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.
```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p\n", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```\n"""
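context2 is a regular (non-raw) Python string literal. As a result, the \n inside the C printf call becomes a real line break. That is why the printed prompt below splits the printf line in two, which is not valid C. The minimal check below shows doubling the backslash as one way to keep a literal backslash followed by n in the prompt.

```python
broken = 'printf("WidgetList ptr=%p\n", WidgetList);'
escaped = 'printf("WidgetList ptr=%p\\n", WidgetList);'

# In a regular string, \n is a newline character, so the C line is split.
print(broken.count("\n"))   # 1
# Doubling the backslash keeps the two characters backslash and n.
print(escaped.count("\n"))  # 0
print(escaped)              # printf("WidgetList ptr=%p\n", WidgetList);
```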

Format the prompt.

prompt2 = format_instruction_prompt(context2)
print(prompt2)
<start_of_turn>user

You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.

```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p
", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```
<end_of_turn>
<start_of_turn>model

Run the prompt.

print(gemma_lm_7b_it.generate(prompt2, max_length=1000))
<start_of_turn>user

You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.

```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p
", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```
<end_of_turn>
<start_of_turn>model
Yes, the code is vulnerable to a memory access error.

**Reasoning:**

* The code allocates memory for `WidgetList` using `malloc` based on the value of `numWidgets`.
* However, the loop iterates from `0` to `numWidgets`, which is one element beyond the allocated memory.
* This means that accessing `WidgetList[numWidgets]` will result in a memory access error, as it is outside the bounds of the allocated memory.

**Example of Memory Access Error:**

When `numWidgets` is 5, the code allocates memory for `WidgetList` as follows:

```
WidgetList = (Widget **) malloc(5 * sizeof(Widget *));
```

The loop iterates from 0 to 4, accessing the following elements:

* `WidgetList[0]`
* `WidgetList[1]`
* `WidgetList[2]`
* `WidgetList[3]`
* `WidgetList[4]`

However, the code then attempts to access `WidgetList[5]`, which is outside the allocated memory range. This will result in a memory access error.

**Solution:**

To resolve this vulnerability, the loop should be modified to iterate from 0 to `numWidgets - 1`:

```cpp
for (i = 0; i < numWidgets - 1; i++) {
    WidgetList[i] = InitializeWidget();
}
```

This ensures that the loop does not access elements beyond the allocated memory range.

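The model flags the code as vulnerable and proposes a change, but review its reasoning before applying it. The loop already stops at numWidgets - 1, so the out-of-bounds write comes only from WidgetList[numWidgets] = NULL. The suggested loop change leaves that write in place and also leaves the last element uninitialized. Allocating numWidgets + 1 pointers, for example, removes the out-of-bounds write.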
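The indexing in the C snippet can be mimicked in Python, where a fixed-size list raises IndexError instead of silently writing past the end. build_widget_list is a hypothetical stand-in for the C code. The list plays the role of the malloc'd array, and reserving one extra slot for the NULL sentinel is one possible fix.

```python
from typing import List, Optional


def build_widget_list(num_widgets: int, reserve_sentinel: bool) -> List[Optional[str]]:
    """Mimic the C snippet with a fixed-size list standing in for malloc."""
    size = num_widgets + 1 if reserve_sentinel else num_widgets
    widgets: List[Optional[str]] = [None] * size  # like malloc(size * sizeof(Widget *))
    for i in range(num_widgets):                  # i = 0 .. num_widgets - 1
        widgets[i] = f"widget{i}"                 # stands in for InitializeWidget()
    widgets[num_widgets] = None                   # the NULL sentinel write
    return widgets


try:
    build_widget_list(5, reserve_sentinel=False)
except IndexError as err:
    print("out of bounds:", err)  # index 5 is one past the last slot

print(len(build_widget_list(5, reserve_sentinel=True)))  # 6
```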

Summary

This tutorial walked you through using CodeGemma for a variety of coding tasks. To learn more about CodeGemma: