Overview
CodeGemma is a variant of Gemma that is fine-tuned for coding tasks. This tutorial builds on the Keras CodeGemma quickstart and shows more ways CodeGemma can help with programming tasks.
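Setup
Get access to CodeGemma
To complete this tutorial, first follow the setup instructions at Gemma setup. The Gemma setup instructions show you how to do the following:
- Get access to Gemma on kaggle.com.
- Select a Colab runtime with sufficient resources to run the Gemma 7B model.
- Generate and configure a Kaggle username and API key.
After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.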
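Select the runtime
To run the CodeGemma 7B models, you need a paid Colab Pro plan, which provides a runtime with an A100 GPU.
- In the upper-right of the Colab window, select ▾ (Additional connection options).
- Select Change runtime type.
- Under Hardware accelerator, select A100 GPU.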
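Configure your API key
To use Gemma, you must provide your Kaggle username and a Kaggle API key.
To generate a Kaggle API key, go to the Account tab of your Kaggle user profile and select Create New Token. This triggers the download of a kaggle.json file containing your API credentials.
In Colab, select Secrets (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name KAGGLE_USERNAME and your API key under the name KAGGLE_KEY.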
Set environment variables
Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY.
import os
from google.colab import userdata
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
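Outside Colab, google.colab.userdata is not available. As a minimal sketch, assuming you keep the downloaded kaggle.json file on disk (it holds a `username` field and a `key` field), the same two variables could be set from that file instead. The helper name `load_kaggle_credentials` and the example path are hypothetical, not part of Colab or KerasNLP.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path):
    """Read a kaggle.json file and export its credentials as environment variables.

    Hypothetical helper for illustration only.
    """
    creds = json.loads(Path(path).read_text())
    os.environ["KAGGLE_USERNAME"] = creds["username"]
    os.environ["KAGGLE_KEY"] = creds["key"]
    return creds["username"]


# Example call (the path is an assumption; point it at wherever you saved the file):
# load_kaggle_credentials(Path.home() / ".kaggle" / "kaggle.json")
```

Reading the file at runtime keeps the key out of the notebook source, which is the same goal the Colab Secrets pane serves.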
Install dependencies
pip install -q -U keras-nlp
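Select a backend
Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. With Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.
For this tutorial, configure the backend for JAX.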
os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch".
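Keras reads KERAS_BACKEND when it is first imported, so the variable has to be set before `import keras` runs. The sketch below wraps the assignment in a small guard that rejects unknown names and warns when the change comes too late. `select_backend` is a hypothetical helper, not a Keras API.

```python
import os
import sys
import warnings

SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name):
    """Set KERAS_BACKEND, rejecting unknown names and warning if Keras is already loaded."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}")
    if "keras" in sys.modules:
        # The backend is fixed at import time, so a later change has no effect.
        warnings.warn("Keras is already imported; restart the runtime to switch backends")
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("jax")
```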
Import packages
Import Keras and KerasNLP.
import keras_nlp
import keras
# Run at half precision.
keras.config.set_floatx("bfloat16")
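To see why half precision matters on a single GPU, here is a back-of-the-envelope estimate of the memory needed just for the weights. It assumes a nominal 7 billion parameters (the real parameter count of the checkpoint may be higher) and ignores activations and the generation cache.

```python
def weight_memory_gib(num_params, bytes_per_param):
    """Approximate memory needed to hold the weights, in GiB."""
    return num_params * bytes_per_param / 2**30


NOMINAL_PARAMS = 7e9  # assumption: nominal "7B"; the actual count may differ

bf16_gib = weight_memory_gib(NOMINAL_PARAMS, 2)  # bfloat16 stores 2 bytes per value
fp32_gib = weight_memory_gib(NOMINAL_PARAMS, 4)  # float32 stores 4 bytes per value
print(f"bfloat16: {bf16_gib:.1f} GiB, float32: {fp32_gib:.1f} GiB")
```

Under these assumptions, bfloat16 halves the weight footprint compared with float32.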
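CodeGemma 7B model examples
This section walks through examples of using the pretrained 7B CodeGemma model to help with coding tasks.
Load the model
KerasNLP provides implementations of all three CodeGemma variants (2B and 7B pretrained (PT), and 7B instruction-tuned (IT)) using GemmaCausalLM, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on the previous tokens.
For this example, load the code_gemma_7b_en model using the from_preset method.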
gemma_lm_7b = keras_nlp.models.GemmaCausalLM.from_preset("code_gemma_7b_en")
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/config.json...
100%|██████████| 556/556 [00:00<00:00, 790kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/model.weights.h5...
100%|██████████| 15.9G/15.9G [02:39<00:00, 107MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/tokenizer.json...
100%|██████████| 401/401 [00:00<00:00, 587kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.04M/4.04M [00:00<00:00, 16.4MB/s]
gemma_lm_7b.summary()
The from_preset method instantiates the model from a preset architecture and weights.
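Multi-line code completion with FIM
PT CodeGemma models are trained on code infilling tasks. This section shows examples that use CodeGemma's multi-line fill-in-the-middle (FIM) capability to auto-fill code at a specified cursor location based on the surrounding context.
First, define constants and a prompt formatting helper function.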
# Formatting control tokens to specify cursor location
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
# Define model stop tokens
END_TOKEN = gemma_lm_7b.preprocessor.tokenizer.end_token
stop_tokens = (BEFORE_CURSOR, AFTER_CURSOR, AT_CURSOR, FILE_SEPARATOR, END_TOKEN)
stop_token_ids = tuple(gemma_lm_7b.preprocessor.tokenizer.token_to_id(x) for x in stop_tokens)
def format_completion_prompt(before, after):
    return f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
Example 1 - Insert missing condition
The example code below, which generates the Fibonacci sequence, does not execute correctly if n=1: fibonacci(1) calls fibonacci(-1), which never reaches the n == 0 base case.
def fibonacci(n: int) -> int:
  if n == 0:
    return 0
  # The cursor is right before the e in the following line
  else:
    return fibonacci(n - 1) + fibonacci(n - 2)
Assuming the cursor is at the beginning of line 4 (where the else clause is located), the content before and after the cursor is:
before = """def fibonacci(n: int) -> int:\n  if n == 0:\n    return 0\n""" # Mind the spaces!
after = """\n  else:\n    return fibonacci(n - 1) + fibonacci(n-2)\n"""
prompt = format_completion_prompt(before, after)
print(prompt)
<|fim_prefix|>def fibonacci(n: int) -> int:
  if n == 0:
    return 0
<|fim_suffix|>
  else:
    return fibonacci(n - 1) + fibonacci(n-2)
<|fim_middle|>
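In an editor integration, the before and after strings would normally come from the buffer contents and a cursor position rather than hand-written literals. The following is a minimal sketch of that step, assuming a 1-based line number and a 0-based column. `split_at_cursor` is a hypothetical helper, and the control tokens and formatting function are repeated so the block runs on its own.

```python
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"


def split_at_cursor(source, line, column=0):
    """Split source text at a cursor given as a 1-based line and 0-based column."""
    lines = source.splitlines(keepends=True)
    if not 1 <= line <= len(lines) + 1:
        raise ValueError(f"line {line} is outside the source")
    # Character offset of the cursor: full length of all earlier lines plus the column.
    offset = sum(len(text) for text in lines[: line - 1]) + column
    return source[:offset], source[offset:]


def format_completion_prompt(before, after):
    return f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"


source = (
    "def fibonacci(n: int) -> int:\n"
    "  if n == 0:\n"
    "    return 0\n"
    "  else:\n"
    "    return fibonacci(n - 1) + fibonacci(n - 2)\n"
)
before, after = split_at_cursor(source, line=4)
prompt = format_completion_prompt(before, after)
```

Unlike the hand-written strings, this split does not add an extra newline in front of the else clause; whether that extra newline helps the model is something to check empirically.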
Run the prompt.
print(gemma_lm_7b.generate(prompt, stop_token_ids=stop_token_ids, max_length=128))
<|fim_prefix|>def fibonacci(n: int) -> int:
  if n == 0:
    return 0
<|fim_suffix|>
  else:
    return fibonacci(n - 1) + fibonacci(n-2)
<|fim_middle|>elif n == 1:
    return 1<|file_separator|>
The model inserts the correct elif condition for n=1 at the cursor location.
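The generate call returns the prompt followed by the completion and a stop token. To insert only the new code into an editor, the completion has to be cut out of that string. Here is a minimal sketch of that step, using a hard-coded sample string in place of real model output. `extract_completion` is a hypothetical helper, and the whitespace in the sample is assumed.

```python
AT_CURSOR = "<|fim_middle|>"
STOP_STRINGS = ("<|fim_prefix|>", "<|fim_suffix|>", "<|fim_middle|>", "<|file_separator|>")


def extract_completion(generated):
    """Return the text after the first <|fim_middle|> marker, cut at the first stop string."""
    _, marker, completion = generated.partition(AT_CURSOR)
    if not marker:
        raise ValueError("no <|fim_middle|> marker in generated text")
    cut = len(completion)
    for stop in STOP_STRINGS:
        index = completion.find(stop)
        if index != -1:
            cut = min(cut, index)
    return completion[:cut]


# Sample standing in for generate() output; it was not produced by running the model here.
generated = (
    "<|fim_prefix|>def fibonacci(n: int) -> int:\n  if n == 0:\n    return 0\n"
    "<|fim_suffix|>\n  else:\n    return fibonacci(n - 1) + fibonacci(n-2)\n"
    "<|fim_middle|>elif n == 1:\n    return 1<|file_separator|>"
)
middle = extract_completion(generated)
```

The extracted text can then be placed between the before and after strings to produce the completed source.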
Example 2 - Complete DFS traversal algorithm
Auto-complete code for a depth-first search (DFS) tree traversal algorithm.
before = """void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }"""
after = """\nprintf("%d", root->value);
}"""
prompt = format_completion_prompt(before, after)
print(prompt)
<|fim_prefix|>void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }<|fim_suffix|>
printf("%d", root->value);
}<|fim_middle|>
Run the prompt.
print(gemma_lm_7b.generate(prompt, stop_token_ids=stop_token_ids, max_length=128))
<|fim_prefix|>void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }<|fim_suffix|>
printf("%d", root->value);
}<|fim_middle|>
  if (root->right) {
    dfs(root->right);
  }<|file_separator|>
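Code generation
In addition to code infilling, the CodeGemma 7B PT model is also trained on natural language corpora. You can use this to prompt the model to generate code.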
generation_prompt= """Write a rust function to identify non-prime numbers.
Examples:
>>> is_not_prime(2)
False
>>> is_not_prime(10)
True
pub fn is_not_prime(n: i32) -> bool {"""
print(gemma_lm_7b.generate(generation_prompt, max_length=500))
Write a rust function to identify non-prime numbers.
Examples:
>>> is_not_prime(2)
False
>>> is_not_prime(10)
True
pub fn is_not_prime(n: i32) -> bool {
    if n <= 1 {
        return true;
    }
    for i in 2..n {
        if n % i == 0 {
            return true;
        }
    }
    false
}
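One way to sanity-check generated code is to run it against the examples given in the prompt. The sketch below is a hand-written Python transliteration of the generated Rust function (it is not model output), checked against those two examples.

```python
def is_not_prime(n):
    """Python transliteration of the generated Rust function."""
    if n <= 1:
        return True
    for i in range(2, n):  # mirrors Rust's `for i in 2..n` (upper bound excluded)
        if n % i == 0:
            return True
    return False


# The two examples from the prompt.
assert is_not_prime(2) is False
assert is_not_prime(10) is True
```

Trial division up to n - 1 gives correct results but is slow for large inputs; stopping at the square root of n would be enough.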
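7B IT model examples
This section uses the CodeGemma 7B instruction-tuned model for more advanced coding tasks. The CodeGemma 7B IT model is derived from the CodeGemma 7B PT model through supervised fine-tuning on code, along with reinforcement learning from human feedback. This section covers examples of using this model for open-ended generation.
Load the IT model
Load the code_gemma_instruct_7b_en model using the from_preset method.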
gemma_lm_7b_it = keras_nlp.models.GemmaCausalLM.from_preset("code_gemma_instruct_7b_en")
gemma_lm_7b_it.summary()
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/config.json...
100%|██████████| 556/556 [00:00<00:00, 754kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/model.weights.h5...
100%|██████████| 15.9G/15.9G [03:18<00:00, 86.2MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/tokenizer.json...
100%|██████████| 401/401 [00:00<00:00, 593kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.04M/4.04M [00:00<00:00, 16.8MB/s]
The IT models are trained with a specific formatter that annotates all instruction-tuning examples with extra information to indicate roles and delineate turns in a conversation.
First, define constants and a prompt formatting helper function.
# Formatting control tokens for instruction tuning
START_OF_TURN_USER = "<start_of_turn>user"
END_OF_TURN = "<end_of_turn>"
START_OF_TURN_MODEL = "<start_of_turn>model"
# Formatting helper function
def format_instruction_prompt(context):
    return f"{START_OF_TURN_USER}\n{context}{END_OF_TURN}\n{START_OF_TURN_MODEL}\n"
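The same control tokens extend naturally to multi-turn conversations. The sketch below assumes each earlier turn is closed with <end_of_turn> and that the prompt ends with an open model turn for the reply. `format_chat_prompt` is a hypothetical helper, not a KerasNLP API, and the conversation content is made up for illustration.

```python
START_OF_TURN_USER = "<start_of_turn>user"
START_OF_TURN_MODEL = "<start_of_turn>model"
END_OF_TURN = "<end_of_turn>"


def format_chat_prompt(turns):
    """Format (role, text) pairs, then open a new model turn for the reply.

    Each role must be "user" or "model".
    """
    starts = {"user": START_OF_TURN_USER, "model": START_OF_TURN_MODEL}
    parts = []
    for role, text in turns:
        parts.append(f"{starts[role]}\n{text}{END_OF_TURN}\n")
    parts.append(f"{START_OF_TURN_MODEL}\n")
    return "".join(parts)


prompt = format_chat_prompt([
    ("user", "Write a Python function that reverses a string."),
    ("model", "def reverse(s):\n    return s[::-1]"),
    ("user", "Now add type hints."),
])
```

With a single user turn, this produces the same string as the single-turn helper.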
Code translation
context1 = """
You are an experienced C and Python programmer. Convert the following Python code into C.
```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```\n"""
Format the prompt.
prompt1 = format_instruction_prompt(context1)
print(prompt1)
<start_of_turn>user

You are an experienced C and Python programmer. Convert the following Python code into C.
```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```
<end_of_turn>
<start_of_turn>model
Run the prompt.
print(gemma_lm_7b_it.generate(prompt1, max_length=500))
<start_of_turn>user

You are an experienced C and Python programmer. Convert the following Python code into C.
```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```
<end_of_turn>
<start_of_turn>model
Here is the C code equivalent of the Python code:

```c
int factorial(int n) {
    int result = 1;
    for (int i = 2; i <= n; i++) {
        result *= i;
    }
    return result;
}
```

Here is a breakdown of the changes:

* The function is declared with the `int` return type, as in Python.
* The `for` loop is converted to a `for` loop with an `int` variable `i` initialized to 2 and incremented by 1 in each iteration.
* The `range` function is replaced with a simple loop that iterates from 2 to `n` (inclusive).
* The `result *= i` statement is used to multiply `result` by `i` in each iteration.
* The `return` statement is used to return the final value of `result`.
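Because generate echoes the prompt, downstream tooling usually wants only the model's reply, or only the code inside it. Here is a minimal sketch of both steps, using a shortened made-up sample string in place of real output. Both helper names are hypothetical.

```python
import re

START_OF_TURN_MODEL = "<start_of_turn>model\n"
END_OF_TURN = "<end_of_turn>"
FENCE = "`" * 3  # built programmatically so this example can sit inside a fenced block


def extract_model_response(generated):
    """Return the text after the last model-turn marker, without a trailing end-of-turn token."""
    response = generated.rpartition(START_OF_TURN_MODEL)[2]
    return response.split(END_OF_TURN, 1)[0].strip()


def extract_code_blocks(markdown):
    """Return (language, code) pairs for each fenced block in a Markdown string."""
    pattern = re.compile(FENCE + r"(\w*)\n(.*?)" + FENCE, re.DOTALL)
    return [(lang, code.rstrip("\n")) for lang, code in pattern.findall(markdown)]


# Shortened stand-in for generate() output; not produced by running the model here.
generated = (
    "<start_of_turn>user\nConvert the following Python code into C.<end_of_turn>\n"
    "<start_of_turn>model\n"
    "Here is the C code:\n"
    + FENCE + "c\nint one(void) {\n    return 1;\n}\n" + FENCE + "\n"
)
reply = extract_model_response(generated)
blocks = extract_code_blocks(reply)
```

Each extracted block carries its language tag, so C output could, for instance, be written to a file and compiled as a further check.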
Code vulnerability detection
context2 = """
You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.
```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;
numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p\n", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```\n"""
Format the prompt.
prompt2 = format_instruction_prompt(context2)
print(prompt2)
<start_of_turn>user

You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.
```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;
numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p
", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```
<end_of_turn>
<start_of_turn>model
print(gemma_lm_7b_it.generate(prompt2, max_length=1000))
<start_of_turn>user

You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.
```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;
numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p
", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```
<end_of_turn>
<start_of_turn>model
Yes, the code is vulnerable to a memory access error.

**Reasoning:**

* The code allocates memory for `WidgetList` using `malloc` based on the value of `numWidgets`.
* However, the loop iterates from `0` to `numWidgets`, which is one element beyond the allocated memory.
* This means that accessing `WidgetList[numWidgets]` will result in a memory access error, as it is outside the bounds of the allocated memory.

**Example of Memory Access Error:**

When `numWidgets` is 5, the code allocates memory for `WidgetList` as follows:

```
WidgetList = (Widget **) malloc(5 * sizeof(Widget *));
```

The loop iterates from 0 to 4, accessing the following elements:

* `WidgetList[0]`
* `WidgetList[1]`
* `WidgetList[2]`
* `WidgetList[3]`
* `WidgetList[4]`

However, the code then attempts to access `WidgetList[5]`, which is outside the allocated memory range. This will result in a memory access error.

**Solution:**

To resolve this vulnerability, the loop should be modified to iterate from 0 to `numWidgets - 1`:

```cpp
for (i = 0; i < numWidgets - 1; i++) {
    WidgetList[i] = InitializeWidget();
}
```

This ensures that the loop does not access elements beyond the allocated memory range.
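The model detects the out-of-bounds access to WidgetList[numWidgets], one element past the end of the allocated array. Note, however, that its explanation attributes the access to the loop, and the suggested loop change leaves the write to WidgetList[numWidgets] in place, so review generated fixes before applying them.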
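To make the off-by-one concrete, the sketch below models the buffer as a fixed-size Python list, where an out-of-range write raises IndexError instead of silently corrupting memory as it can in C++. It illustrates the bug pattern and is not a translation of the C++ snippet; `build_widget_list` and its `slots` parameter are hypothetical.

```python
def build_widget_list(num_widgets, slots):
    """Fill `num_widgets` entries, then write a terminator at index `num_widgets`.

    `slots` is the number of elements allocated, standing in for the malloc size.
    """
    widget_list = [None] * slots
    for i in range(num_widgets):
        widget_list[i] = f"widget-{i}"  # in bounds as long as slots >= num_widgets
    widget_list[num_widgets] = "END"  # needs slots >= num_widgets + 1
    return widget_list


# Allocating exactly num_widgets slots reproduces the out-of-bounds write.
try:
    build_widget_list(5, slots=5)
    overflowed = False
except IndexError:
    overflowed = True

# Reserving one extra slot for the terminator keeps every write in bounds.
fixed = build_widget_list(5, slots=5 + 1)
```

In the C++ code, the equivalent change is to allocate numWidgets + 1 pointers. Shrinking the loop bound, as in the model's suggested fix, would leave the terminator write out of bounds and the last widget uninitialized.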
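Summary
This tutorial showed you how to use CodeGemma for a variety of coding tasks. To learn more about CodeGemma:
- See the CodeGemma model card for the technical specifications of the CodeGemma models.
- Learn more about how to use CodeGemma in Vertex AI.
- Check out the Keras CodeGemma quickstart.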