การเขียนโปรแกรมโดยใช้ AI ด้วย CodeGemma และ KerasNLP

ดูใน ai.google.dev เรียกใช้ใน Google Colab ดูแหล่งที่มาใน GitHub

ภาพรวม

CodeGemma เป็นตัวแปรของ Gemma ที่ได้รับการปรับแต่งให้เหมาะกับงานการเขียนโค้ด บทแนะนำนี้สร้างขึ้นจากการเริ่มต้นอย่างรวดเร็ว Keras CodeGemma และแสดงวิธีอื่นๆ ที่ CodeGemma สามารถช่วยงานเขียนโปรแกรมของคุณได้

ตั้งค่า

รับสิทธิ์เข้าถึง CodeGemma

หากต้องการจบบทแนะนำนี้ คุณจะต้องทำตามวิธีการตั้งค่าที่การตั้งค่า Gemma ก่อน วิธีการตั้งค่า Gemma จะแสดงวิธีดำเนินการต่อไปนี้

  • รับสิทธิ์เข้าถึง Gemma ใน kaggle.com
  • เลือกรันไทม์ของ Colab ที่มีทรัพยากรเพียงพอที่จะเรียกใช้โมเดล Gemma 7B
  • สร้างและกำหนดค่าชื่อผู้ใช้และคีย์ API ของ Kaggle

หลังจากตั้งค่า Gemma เสร็จแล้ว ให้ไปยังส่วนถัดไปซึ่งจะตั้งค่าตัวแปรสภาพแวดล้อมสำหรับสภาพแวดล้อม Colab

เลือกรันไทม์

หากต้องการเรียกใช้โมเดล CodeGemma 7B คุณจะต้องมีแพ็กเกจ Colab Pro แบบชำระเงินซึ่งมีรันไทม์พร้อม GPU A100

  1. ที่ด้านขวาบนของหน้าต่าง Colab ให้เลือก ▾ (ตัวเลือกการเชื่อมต่อเพิ่มเติม)
  2. เลือกเปลี่ยนประเภทรันไทม์
  3. ในส่วนตัวเร่งฮาร์ดแวร์ ให้เลือก GPU A100

กำหนดค่าคีย์ API

หากต้องการใช้ Gemma คุณต้องระบุชื่อผู้ใช้ Kaggle และคีย์ API ของ Kaggle

หากต้องการสร้างคีย์ Kaggle API ให้ไปที่แท็บ Account ของโปรไฟล์ผู้ใช้ Kaggle และเลือก Create New Token การดำเนินการนี้จะทริกเกอร์การดาวน์โหลดไฟล์ kaggle.json ที่มีข้อมูลเข้าสู่ระบบ API ของคุณ

ใน Colab ให้เลือก Secrets (🔑) ในแผงด้านซ้าย แล้วเพิ่มชื่อผู้ใช้ Kaggle และคีย์ API ของ Kaggle จัดเก็บชื่อผู้ใช้ในชื่อ KAGGLE_USERNAME และคีย์ API ในชื่อ KAGGLE_KEY

ตั้งค่าตัวแปรสภาพแวดล้อม

ตั้งค่าตัวแปรสภาพแวดล้อมสำหรับ KAGGLE_USERNAME และ KAGGLE_KEY

import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

ติดตั้งการอ้างอิง

pip install -q -U keras-nlp

เลือกแบ็กเอนด์

Keras เป็น API การเรียนรู้เชิงลึกที่มีหลายกรอบและมีระดับสูงซึ่งออกแบบมาให้ใช้งานง่าย เมื่อใช้ Keras 3 คุณจะเรียกใช้เวิร์กโฟลว์บนแบ็กเอนด์ 1 ใน 3 แบบ ได้แก่ TensorFlow, JAX หรือ PyTorch

กำหนดค่าแบ็กเอนด์สำหรับ JAX ในบทแนะนำนี้

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".

นำเข้าแพ็กเกจ

นำเข้า Keras และ KerasNLP

import keras_nlp
import keras

# Run at half precision.
keras.config.set_floatx("bfloat16")

ตัวอย่างโมเดล CodeGemma 7B

ส่วนนี้จะครอบคลุมตัวอย่างการใช้โมเดล CodeGemma 7B ที่ฝึกล่วงหน้าเพื่อช่วยเรื่องการเขียนโค้ด

โหลดโมเดล

KerasNLP ติดตั้งใช้งานตัวแปร CodeGemma ทั้ง 3 แบบ (2B และ 7B ที่ฝึกล่วงหน้า (PT) และ 7B ในรูปแบบการสอน (IT) โดยใช้ GemmaCausalLM ซึ่งเป็นโมเดล Gemma จากต้นทางถึงปลายทางสำหรับโมเดลภาษาทั่วไป โมเดลภาษาทั่วไปจะคาดการณ์โทเค็นถัดไปตามโทเค็นก่อนหน้า

สำหรับตัวอย่างนี้ ให้โหลดโมเดล code_gemma_7b_en โดยใช้เมธอด from_preset

gemma_lm_7b = keras_nlp.models.GemmaCausalLM.from_preset("code_gemma_7b_en")
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/config.json...
100%|██████████| 556/556 [00:00<00:00, 790kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/model.weights.h5...
100%|██████████| 15.9G/15.9G [02:39<00:00, 107MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/tokenizer.json...
100%|██████████| 401/401 [00:00<00:00, 587kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.04M/4.04M [00:00<00:00, 16.4MB/s]
gemma_lm_7b.summary()

เมธอด from_preset จะสร้างอินสแตนซ์โมเดลจากสถาปัตยกรรมและน้ำหนักที่กำหนดไว้ล่วงหน้า

การเติมโค้ดด้วย FIM หลายบรรทัด

โมเดล PT CodeGemma ได้รับการฝึกเกี่ยวกับงานใส่โค้ด ส่วนนี้จะแสดงตัวอย่างที่ใช้ความสามารถในการเติมจุดกึ่งกลางแบบหลายบรรทัด (FIM) ของ CodeGemma ในการป้อนรหัสอัตโนมัติในตำแหน่งเคอร์เซอร์ที่ระบุตามบริบทโดยรอบ

ในขั้นตอนแรก ให้ระบุค่าคงที่และฟังก์ชันตัวช่วยการจัดรูปแบบพรอมต์

# Formatting control tokens to specify cursor location
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

# Define model stop tokens
END_TOKEN = gemma_lm_7b.preprocessor.tokenizer.end_token
stop_tokens = (BEFORE_CURSOR, AFTER_CURSOR, AT_CURSOR, FILE_SEPARATOR, END_TOKEN)
stop_token_ids = tuple(gemma_lm_7b.preprocessor.tokenizer.token_to_id(x) for x in stop_tokens)

def format_completion_prompt(before, after):
    return f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"

ตัวอย่างที่ 1 - แทรกเงื่อนไขที่ขาดหายไป

โค้ดตัวอย่างด้านล่างที่ใช้สร้างลำดับ Fibonacci จะไม่ทำงานอย่างถูกต้องหาก n=1:

def fibonacci(n: int) -> int:
  if n == 0:
    return 0
  # The cursor is right before the e in the following line
  else:
    return fibonacci(n - 1) + fibonacci(n - 2)

สมมติว่าเคอร์เซอร์อยู่ที่ต้นบรรทัดที่ 4 (โดยมีวลี else อยู่) เนื้อหาก่อนและหลังเคอร์เซอร์มีดังนี้

before = """def fibonacci(n: int) -> int:\n  if n == 0:\n    return 0\n""" # Mind the spaces!
after = """\n  else:\n    return fibonacci(n - 1) + fibonacci(n-2)\n"""
prompt = format_completion_prompt(before, after)
print(prompt)
<|fim_prefix|>def fibonacci(n: int) -> int:
  if n == 0:
    return 0
<|fim_suffix|>
  else:
    return fibonacci(n - 1) + fibonacci(n-2)
<|fim_middle|>

เรียกใช้ข้อความแจ้ง

print(gemma_lm_7b.generate(prompt, stop_token_ids=stop_token_ids, max_length=128))
<|fim_prefix|>def fibonacci(n: int) -> int:
  if n == 0:
    return 0
<|fim_suffix|>
  else:
    return fibonacci(n - 1) + fibonacci(n-2)
<|fim_middle|>elif n == 1:
    return 1<|file_separator|>

โมเดลนี้จะแทรกเงื่อนไข elif ที่ถูกต้องสำหรับ n=1 ที่ตำแหน่งเคอร์เซอร์

ตัวอย่างที่ 2 - อัลกอริทึมการข้ามผ่าน DFS ที่สมบูรณ์

เติมโค้ดอัตโนมัติสําหรับอัลกอริทึม Tree Traversal แบบ Deep First Search (DFS)

before = """void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }"""
after = """\nprintf("%d", root->value);
}"""
prompt = format_completion_prompt(before, after)
print(prompt)
<|fim_prefix|>void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }<|fim_suffix|>
printf("%d", root->value);
}<|fim_middle|>

เรียกใช้ข้อความแจ้ง

print(gemma_lm_7b.generate(prompt, stop_token_ids=stop_token_ids, max_length=128))
<|fim_prefix|>void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }<|fim_suffix|>
printf("%d", root->value);
}<|fim_middle|>
  if (root->right) {
    dfs(root->right);
  }<|file_separator|>

การสร้างโค้ด

นอกจากการเติมโค้ดแล้ว โมเดล CodeGemma 7B PT ยังได้รับการฝึกเกี่ยวกับคลังภาษาธรรมชาติด้วย คุณสามารถใช้โค้ดนี้เพื่อบอกให้โมเดลสร้างโค้ดได้

generation_prompt= """Write a rust function to identify non-prime numbers.
Examples:
>>> is_not_prime(2)
False
>>> is_not_prime(10)
True
pub fn is_not_prime(n: i32) -> bool {"""
print(gemma_lm_7b.generate(generation_prompt, max_length=500))
Write a rust function to identify non-prime numbers.
Examples:
>>> is_not_prime(2)
False
>>> is_not_prime(10)
True
pub fn is_not_prime(n: i32) -> bool {
    if n <= 1 {
        return true;
    }
    for i in 2..n {
        if n % i == 0 {
            return true;
        }
    }
    false
}

ตัวอย่างโมเดลไอที 7B

ส่วนนี้ใช้รูปแบบ CodeGemma 7B Instruction-Tuned สำหรับงานการเขียนโค้ดขั้นสูงขึ้น โมเดลไอทีของ CodeGemma 7B ได้มาจากโมเดล CodeGemma 7B PT ผ่านการปรับแต่งโค้ดโดยมีการควบคุมดูแล ตลอดจนการเรียนรู้แบบเสริมแรงกับความคิดเห็นจากมนุษย์ ส่วนนี้จะครอบคลุมตัวอย่างการใช้โมเดลนี้สำหรับรุ่นปลายเปิด

โหลดโมเดลไอที

โหลดโมเดล code_gemma_instruct_7b_en โดยใช้เมธอด from_preset

gemma_lm_7b_it = keras_nlp.models.GemmaCausalLM.from_preset("code_gemma_instruct_7b_en")
gemma_lm_7b_it.summary()
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/config.json...
100%|██████████| 556/556 [00:00<00:00, 754kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/model.weights.h5...
100%|██████████| 15.9G/15.9G [03:18<00:00, 86.2MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/tokenizer.json...
100%|██████████| 401/401 [00:00<00:00, 593kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.04M/4.04M [00:00<00:00, 16.8MB/s]

โมเดลไอทีจะได้รับการฝึกด้วยตัวจัดรูปแบบที่เจาะจงซึ่งจะทำคำอธิบายประกอบตัวอย่างการปรับแต่งคำสั่งทั้งหมดด้วยข้อมูลเพิ่มเติมเพื่อระบุบทบาทและอธิบายการเลี้ยวในการสนทนา

ในขั้นแรก ให้ระบุค่าคงที่และฟังก์ชันตัวช่วยการจัดรูปแบบพรอมต์

# Formatting control tokens for instruction tuning
START_OF_TURN_USER = "<start_of_turn>user"
END_OF_TURN = "<end_of_turn>"
START_OF_TURN_MODEL = "<start_of_turn>model"

# Formatting helper function
def format_instruction_prompt(context):
    return f"{START_OF_TURN_USER}\n{context}{END_OF_TURN}\n{START_OF_TURN_MODEL}\n"

การแปลโค้ด

context1 = """
You are an experienced C and Python programmer. Convert the following Python code into C.
```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```\n"""

จัดรูปแบบพรอมต์

prompt1 = format_instruction_prompt(context1)
print(prompt1)
<start_of_turn>user

You are an experienced C and Python programmer. Convert the following Python code into C.

```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```
<end_of_turn>
<start_of_turn>model

เรียกใช้ข้อความแจ้ง

print(gemma_lm_7b_it.generate(prompt1, max_length=500))
<start_of_turn>user

You are an experienced C and Python programmer. Convert the following Python code into C.

```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```
<end_of_turn>
<start_of_turn>model
Here is the C code equivalent of the Python code:

```c
int factorial(int n) {
  int result = 1;
  for (int i = 2; i <= n; i++) {
    result *= i;
  }
  return result;
}
```

Here is a breakdown of the changes:

* The function is declared with the `int` return type, as in Python.
* The `for` loop is converted to a `for` loop with an `int` variable `i` initialized to 2 and incremented by 1 in each iteration.
* The `range` function is replaced with a simple loop that iterates from 2 to `n` (inclusive).
* The `result *= i` statement is used to multiply `result` by `i` in each iteration.
* The `return` statement is used to return the final value of `result`.

การตรวจจับช่องโหว่ของโค้ด

context2 = """
You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.
```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p\n", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```\n"""

จัดรูปแบบพรอมต์

prompt2 = format_instruction_prompt(context2)
print(prompt2)
<start_of_turn>user

You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.

```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p
", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```
<end_of_turn>
<start_of_turn>model
print(gemma_lm_7b_it.generate(prompt2, max_length=1000))
<start_of_turn>user

You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.

```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p
", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```
<end_of_turn>
<start_of_turn>model
Yes, the code is vulnerable to a memory access error.

**Reasoning:**

* The code allocates memory for `WidgetList` using `malloc` based on the value of `numWidgets`.
* However, the loop iterates from `0` to `numWidgets`, which is one element beyond the allocated memory.
* This means that accessing `WidgetList[numWidgets]` will result in a memory access error, as it is outside the bounds of the allocated memory.

**Example of Memory Access Error:**

When `numWidgets` is 5, the code allocates memory for `WidgetList` as follows:

```
WidgetList = (Widget **) malloc(5 * sizeof(Widget *));
```

The loop iterates from 0 to 4, accessing the following elements:

* `WidgetList[0]`
* `WidgetList[1]`
* `WidgetList[2]`
* `WidgetList[3]`
* `WidgetList[4]`

However, the code then attempts to access `WidgetList[5]`, which is outside the allocated memory range. This will result in a memory access error.

**Solution:**

To resolve this vulnerability, the loop should be modified to iterate from 0 to `numWidgets - 1`:

```cpp
for (i = 0; i < numWidgets - 1; i++) {
    WidgetList[i] = InitializeWidget();
}
```

This ensures that the loop does not access elements beyond the allocated memory range.

โมเดลจะตรวจหาช่องโหว่ที่อาจเกิดขึ้นในโค้ดและทำการเปลี่ยนแปลงโค้ดเพื่อลดช่องโหว่

สรุป

บทแนะนำนี้จะแนะนำวิธีใช้ CodeGemma สำหรับงานเขียนโค้ดที่หลากหลาย วิธีดูข้อมูลเพิ่มเติมเกี่ยวกับ CodeGemma