CodeGemma와 KerasNLP를 사용한 AI 지원 프로그래밍

ai.google.dev에서 보기 Google Colab에서 실행 GitHub에서 소스 보기

개요

CodeGemma는 코딩 작업에 맞춰 미세 조정된 Gemma의 변종입니다. 이 가이드는 Keras CodeGemma 빠른 시작을 기반으로 하며 CodeGemma가 프로그래밍 작업을 지원하는 다양한 방법을 보여줍니다.

설정

CodeGemma에 액세스하기

이 튜토리얼을 완료하려면 먼저 Gemma 설정에서 설정 안내를 완료해야 합니다. Gemma 설정 안내에서는 다음 작업을 수행하는 방법을 보여줍니다.

  • kaggle.com에서 Gemma에 액세스하세요.
  • Gemma 7B 모델을 실행하기에 충분한 리소스가 있는 Colab 런타임을 선택하세요.
  • Kaggle 사용자 이름 및 API 키를 생성하고 구성합니다.

Gemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.

런타임 선택

CodeGemma 7B 모델을 실행하려면 A100 GPU가 포함된 런타임을 제공하는 Colab Pro 유료 요금제를 이용해야 합니다.

  1. Colab 창 오른쪽 상단에서 ▾ (추가 연결 옵션)를 선택합니다.
  2. 런타임 유형 변경을 선택합니다.
  3. 하드웨어 가속기에서 A100 GPU를 선택합니다.

API 키 구성

Gemma를 사용하려면 Kaggle 사용자 이름과 Kaggle API 키를 제공해야 합니다.

Kaggle API 키를 생성하려면 Kaggle 사용자 프로필의 계정 탭으로 이동하여 새 토큰 만들기를 선택합니다. 이렇게 하면 API 사용자 인증 정보가 포함된 kaggle.json 파일의 다운로드가 트리거됩니다.

Colab에서 왼쪽 창에 있는 Secrets (VC)를 선택하고 Kaggle 사용자 이름과 Kaggle API 키를 추가합니다. 사용자 이름을 KAGGLE_USERNAME 이름으로, API 키를 KAGGLE_KEY 이름 아래에 저장합니다.

환경 변수 설정하기

KAGGLE_USERNAMEKAGGLE_KEY의 환경 변수를 설정합니다.

import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

종속 항목 설치

pip install -q -U keras-nlp

백엔드 선택

Keras는 단순성과 사용 편의성을 위해 설계된 높은 수준의 다중 프레임워크 딥 러닝 API입니다. Keras 3을 사용하면 TensorFlow, JAX, PyTorch의 세 가지 백엔드 중 하나에서 워크플로를 실행할 수 있습니다.

이 튜토리얼에서는 JAX 백엔드를 구성합니다.

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".

패키지 가져오기

Keras와 KerasNLP를 가져옵니다.

import keras_nlp
import keras

# Run at half precision.
keras.config.set_floatx("bfloat16")

CodeGemma 7B 모델 예

이 섹션에서는 코딩 작업에 도움이 되도록 선행 학습된 7B CodeGemma 모델을 사용하는 예시를 다룹니다.

모델 로드

KerasNLP는 인과 언어 모델링을 위한 엔드 투 엔드 Gemma 모델인 GemmaCausalLM를 사용하여 CodeGemma 변형 (2B 및 7B 선행 학습 (PT) 및 7B 명령 조정 (IT))의 세 가지 구현을 제공합니다. 인과 언어 모델은 이전 토큰을 기반으로 다음 토큰을 예측합니다.

이 예시에서는 from_preset 메서드를 사용하여 code_gemma_7b_en 모델을 로드합니다.

gemma_lm_7b = keras_nlp.models.GemmaCausalLM.from_preset("code_gemma_7b_en")
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/config.json...
100%|██████████| 556/556 [00:00<00:00, 790kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/model.weights.h5...
100%|██████████| 15.9G/15.9G [02:39<00:00, 107MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/tokenizer.json...
100%|██████████| 401/401 [00:00<00:00, 587kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_7b_en/1/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.04M/4.04M [00:00<00:00, 16.4MB/s]
gemma_lm_7b.summary()

from_preset 메서드는 사전 설정된 아키텍처 및 가중치에서 모델을 인스턴스화합니다.

여러 줄의 FIM을 사용한 코드 완성

PT CodeGemma 모델은 코드 입력 작업에 대해 학습됩니다. 이 섹션에서는 CodeGemma의 여러 줄로 된 중간 채우기 (FIM) 기능을 사용하여 주변 컨텍스트에 따라 지정된 커서 위치에서 코드를 자동 완성하는 예를 보여줍니다.

첫 번째 단계로 상수를 정의하고 프롬프트 형식 지정 도우미 함수를 정의합니다.

# Formatting control tokens to specify cursor location
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

# Define model stop tokens
END_TOKEN = gemma_lm_7b.preprocessor.tokenizer.end_token
stop_tokens = (BEFORE_CURSOR, AFTER_CURSOR, AT_CURSOR, FILE_SEPARATOR, END_TOKEN)
stop_token_ids = tuple(gemma_lm_7b.preprocessor.tokenizer.token_to_id(x) for x in stop_tokens)

def format_completion_prompt(before, after):
    return f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"

예 1 - 누락된 조건 삽입

n=1인 경우 피보나치 수열을 생성하는 아래의 예시 코드는 올바르게 실행되지 않습니다.

def fibonacci(n: int) -> int:
  if n == 0:
    return 0
  # The cursor is right before the e in the following line
  else:
    return fibonacci(n - 1) + fibonacci(n - 2)

커서가 else 절이 있는 4번째 줄의 시작 부분에 있다고 가정하면 커서 앞뒤의 내용은 다음과 같습니다.

before = """def fibonacci(n: int) -> int:\n  if n == 0:\n    return 0\n""" # Mind the spaces!
after = """\n  else:\n    return fibonacci(n - 1) + fibonacci(n-2)\n"""
prompt = format_completion_prompt(before, after)
print(prompt)
<|fim_prefix|>def fibonacci(n: int) -> int:
  if n == 0:
    return 0
<|fim_suffix|>
  else:
    return fibonacci(n - 1) + fibonacci(n-2)
<|fim_middle|>

프롬프트를 실행합니다.

print(gemma_lm_7b.generate(prompt, stop_token_ids=stop_token_ids, max_length=128))
<|fim_prefix|>def fibonacci(n: int) -> int:
  if n == 0:
    return 0
<|fim_suffix|>
  else:
    return fibonacci(n - 1) + fibonacci(n-2)
<|fim_middle|>elif n == 1:
    return 1<|file_separator|>

모델이 커서 위치에 n=1의 올바른 elif 전환을 삽입합니다.

예 2 - 완전한 DFS 순회 알고리즘

깊이 우선 검색 (DFS) 트리 순회 알고리즘의 자동 완성 코드

before = """void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }"""
after = """\nprintf("%d", root->value);
}"""
prompt = format_completion_prompt(before, after)
print(prompt)
<|fim_prefix|>void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }<|fim_suffix|>
printf("%d", root->value);
}<|fim_middle|>

프롬프트를 실행합니다.

print(gemma_lm_7b.generate(prompt, stop_token_ids=stop_token_ids, max_length=128))
<|fim_prefix|>void dfs(node* root) {
  if (root->left) {
    dfs(root->left);
  }<|fim_suffix|>
printf("%d", root->value);
}<|fim_middle|>
  if (root->right) {
    dfs(root->right);
  }<|file_separator|>

코드 생성

코드 입력 외에도 CodeGemma 7B PT 모델은 자연어 코퍼스로도 학습됩니다. 이를 사용하여 모델에 코드를 생성하라는 메시지를 표시할 수 있습니다.

generation_prompt= """Write a rust function to identify non-prime numbers.
Examples:
>>> is_not_prime(2)
False
>>> is_not_prime(10)
True
pub fn is_not_prime(n: i32) -> bool {"""
print(gemma_lm_7b.generate(generation_prompt, max_length=500))
Write a rust function to identify non-prime numbers.
Examples:
>>> is_not_prime(2)
False
>>> is_not_prime(10)
True
pub fn is_not_prime(n: i32) -> bool {
    if n <= 1 {
        return true;
    }
    for i in 2..n {
        if n % i == 0 {
            return true;
        }
    }
    false
}

70억 IT 모델 예시

이 섹션에서는 고급 코딩 작업에 CodeGemma 7B 명령 조정 모델을 사용합니다. CodeGemma 7B IT 모델은 인간 피드백 기반 강화 학습과 함께 코드에 대한 지도 미세 조정을 통해 CodeGemma 7B PT 모델에서 파생됩니다. 이 섹션에서는 이 모델을 서술형 생성에 사용하는 예를 설명합니다.

IT 모델 로드

from_preset 메서드를 사용하여 code_gemma_instruct_7b_en 모델을 로드합니다.

gemma_lm_7b_it = keras_nlp.models.GemmaCausalLM.from_preset("code_gemma_instruct_7b_en")
gemma_lm_7b_it.summary()
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/config.json...
100%|██████████| 556/556 [00:00<00:00, 754kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/model.weights.h5...
100%|██████████| 15.9G/15.9G [03:18<00:00, 86.2MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/tokenizer.json...
100%|██████████| 401/401 [00:00<00:00, 593kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_instruct_7b_en/1/download/assets/tokenizer/vocabulary.spm...
100%|██████████| 4.04M/4.04M [00:00<00:00, 16.8MB/s]

IT 모델은 특정 형식 지정 도구로 학습됩니다. 이 형식은 대화에서 역할을 나타내고 차례를 설명하는 추가 정보를 사용하여 모든 명령 조정 예시에 주석을 추가합니다.

첫 번째 단계로 상수를 정의하고 프롬프트 형식 지정 도우미 함수를 정의합니다.

# Formatting control tokens for instruction tuning
START_OF_TURN_USER = "<start_of_turn>user"
END_OF_TURN = "<end_of_turn>"
START_OF_TURN_MODEL = "<start_of_turn>model"

# Formatting helper function
def format_instruction_prompt(context):
    return f"{START_OF_TURN_USER}\n{context}{END_OF_TURN}\n{START_OF_TURN_MODEL}\n"

코드 번역

context1 = """
You are an experienced C and Python programmer. Convert the following Python code into C.
```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```\n"""

프롬프트의 형식을 지정합니다.

prompt1 = format_instruction_prompt(context1)
print(prompt1)
<start_of_turn>user

You are an experienced C and Python programmer. Convert the following Python code into C.

```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```
<end_of_turn>
<start_of_turn>model

프롬프트를 실행합니다.

print(gemma_lm_7b_it.generate(prompt1, max_length=500))
<start_of_turn>user

You are an experienced C and Python programmer. Convert the following Python code into C.

```python
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
```
<end_of_turn>
<start_of_turn>model
Here is the C code equivalent of the Python code:

```c
int factorial(int n) {
  int result = 1;
  for (int i = 2; i <= n; i++) {
    result *= i;
  }
  return result;
}
```

Here is a breakdown of the changes:

* The function is declared with the `int` return type, as in Python.
* The `for` loop is converted to a `for` loop with an `int` variable `i` initialized to 2 and incremented by 1 in each iteration.
* The `range` function is replaced with a simple loop that iterates from 2 to `n` (inclusive).
* The `result *= i` statement is used to multiply `result` by `i` in each iteration.
* The `return` statement is used to return the final value of `result`.

코드 취약점 감지

context2 = """
You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.
```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p\n", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```\n"""

프롬프트의 형식을 지정합니다.

prompt2 = format_instruction_prompt(context2)
print(prompt2)
<start_of_turn>user

You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.

```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p
", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```
<end_of_turn>
<start_of_turn>model
print(gemma_lm_7b_it.generate(prompt2, max_length=1000))
<start_of_turn>user

You are an experienced C++ programmer hunting for vulnerable code. Is the following code vulnerable? Explain your reasoning.

```cpp
int i;
unsigned int numWidgets;
Widget **WidgetList;

numWidgets = GetUntrustedSizeValue();
if ((numWidgets == 0) || (numWidgets > MAX_NUM_WIDGETS)) {
    ExitError("Incorrect number of widgets requested!");
}
WidgetList = (Widget **) malloc(numWidgets * sizeof(Widget *));
printf("WidgetList ptr=%p
", WidgetList);
for (i = 0; i < numWidgets; i++) {
    WidgetList[i] = InitializeWidget();
}
WidgetList[numWidgets] = NULL;
showWidgets(WidgetList);
```
<end_of_turn>
<start_of_turn>model
Yes, the code is vulnerable to a memory access error.

**Reasoning:**

* The code allocates memory for `WidgetList` using `malloc` based on the value of `numWidgets`.
* However, the loop iterates from `0` to `numWidgets`, which is one element beyond the allocated memory.
* This means that accessing `WidgetList[numWidgets]` will result in a memory access error, as it is outside the bounds of the allocated memory.

**Example of Memory Access Error:**

When `numWidgets` is 5, the code allocates memory for `WidgetList` as follows:

```
WidgetList = (Widget **) malloc(5 * sizeof(Widget *));
```

The loop iterates from 0 to 4, accessing the following elements:

* `WidgetList[0]`
* `WidgetList[1]`
* `WidgetList[2]`
* `WidgetList[3]`
* `WidgetList[4]`

However, the code then attempts to access `WidgetList[5]`, which is outside the allocated memory range. This will result in a memory access error.

**Solution:**

To resolve this vulnerability, the loop should be modified to iterate from 0 to `numWidgets - 1`:

```cpp
for (i = 0; i < numWidgets - 1; i++) {
    WidgetList[i] = InitializeWidget();
}
```

This ensures that the loop does not access elements beyond the allocated memory range.

모델이 코드의 잠재적인 취약점을 감지하고 이를 완화할 수 있는 코드 변경사항을 제공합니다.

요약

이 튜토리얼에서는 다양한 코딩 작업에 CodeGemma를 사용하는 방법을 안내했습니다. CodeGemma에 대해 자세히 알아보려면 다음 단계를 따르세요.