Skip to content

Commit

Permalink
rebase to the latest
Browse files Browse the repository at this point in the history
  • Loading branch information
ds5t5 committed Sep 29, 2023
1 parent 8b8c6d5 commit af19099
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 55 deletions.
23 changes: 21 additions & 2 deletions convert-refact-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any

import numpy as np
import torch
Expand Down Expand Up @@ -235,6 +233,27 @@ def parse_args() -> argparse.Namespace:
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(dir_model / part_name, map_location="cpu")

for i in range(block_count):
if f"transformer.h.{i}.attn.kv.weight" in model_part:
data = model_part[f"transformer.h.{i}.attn.kv.weight"]
model_part[f"model.layers.{i}.self_attn.k_proj.weight"] = data[
: n_head_kv * head_dim
]
model_part[f"model.layers.{i}.self_attn.v_proj.weight"] = data[
n_head_kv * head_dim :
]
del model_part[f"transformer.h.{i}.attn.kv.weight"]
if f"transformer.h.{i}.attn.q.weight" in model_part:
model_part[f"model.layers.{i}.self_attn.q_proj.weight"] = model_part[
f"transformer.h.{i}.attn.q.weight"
]
del model_part[f"transformer.h.{i}.attn.q.weight"]
if f"transformer.h.{i}.mlp.gate_up_proj.weight" in model_part:
data = model_part[f"transformer.h.{i}.mlp.gate_up_proj.weight"]
model_part[f"model.layers.{i}.mlp.gate_proj.weight"] = data[:ff_dim]
model_part[f"model.layers.{i}.mlp.up_proj.weight"] = data[ff_dim:]
del model_part[f"transformer.h.{i}.mlp.gate_up_proj.weight"]

for name in model_part.keys():
data = model_part[name]

Expand Down
9 changes: 2 additions & 7 deletions gguf-py/gguf/gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,21 +286,18 @@ class TensorNameMap:
# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf
"transformer.h.{bid}.attn.q", # refact
"layers.{bid}.attention.wq", # llama-pth
),

# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf
"transformer.h.{bid}.attn.k", # refact
"layers.{bid}.attention.wk", # llama-pth
),

# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf
"transformer.h.{bid}.attn.v", # refact
"layers.{bid}.attention.wv", # llama-pth
),

Expand Down Expand Up @@ -335,15 +332,13 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.c_fc", # gpt2
"transformer.blocks.{bid}.ffn.up_proj", # mpt
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"model.layers.{bid}.mlp.up_proj", # llama-hf
"model.layers.{bid}.mlp.up_proj", # llama-hf refact
"layers.{bid}.feed_forward.w3", # llama-pth
"transformer.h.{bid}.mlp.linear_3", # refact
),

# Feed-forward gate
MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj", # llama-hf
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
"layers.{bid}.feed_forward.w1", # llama-pth
),

Expand Down
110 changes: 64 additions & 46 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3369,25 +3369,18 @@ static struct ggml_cgraph * llm_build_baichaun(

static struct ggml_cgraph * llm_build_refact(
llama_context & lctx,
const llama_token * tokens,
const float * embd,
int n_tokens,
int n_past) {

GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT

const int N = n_tokens;

const llama_batch & batch) {
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;

const auto & kv_self = lctx.kv_self;

GGML_ASSERT(!!kv_self.ctx);

const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
const int64_t n_ctx = hparams.n_ctx;
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
Expand All @@ -3397,6 +3390,12 @@ static struct ggml_cgraph * llm_build_refact(

const int n_gpu_layers = model.n_gpu_layers;

const int32_t n_tokens = batch.n_tokens;
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;

// printf("n_kv = %d\n", n_kv);

auto & buf_compute = lctx.buf_compute;

struct ggml_init_params params = {
Expand All @@ -3414,12 +3413,12 @@ static struct ggml_cgraph * llm_build_refact(
struct ggml_tensor * cur;
struct ggml_tensor * inpL;

if (tokens) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);

ggml_allocr_alloc(lctx.alloc, inp_tokens);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
}
ggml_set_name(inp_tokens, "inp_tokens");

Expand All @@ -3429,11 +3428,11 @@ static struct ggml_cgraph * llm_build_refact(
GGML_ASSERT(false && "not implemented");
#endif

inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);

ggml_allocr_alloc(lctx.alloc, inpL);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
}
}

Expand All @@ -3442,9 +3441,6 @@ static struct ggml_cgraph * llm_build_refact(

// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
//
// with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
// in that case ggml_cuda_assign_buffers has no effect
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop;
Expand All @@ -3461,12 +3457,36 @@ static struct ggml_cgraph * llm_build_refact(
}
#endif // GGML_USE_CUBLAS

// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head)));
}

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
offload_func_kq(KQ_mask);
ggml_set_name(KQ_mask, "KQ_mask");
ggml_allocr_alloc(lctx.alloc, KQ_mask);
if (!ggml_allocr_is_measure(lctx.alloc)) {
float * data = (float *) KQ_mask->data;
memset(data, 0, ggml_nbytes(KQ_mask));

for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];

for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
}
}
}
}
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");

for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il);
Expand Down Expand Up @@ -3504,36 +3524,33 @@ static struct ggml_cgraph * llm_build_refact(
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");

struct ggml_tensor * Kcur;
struct ggml_tensor * Qcur;
Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N);
Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N);

struct ggml_tensor * Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens);
offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur");

struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");

// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
// compute the transposed [n_tokens, n_embd] V matrix

struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
offload_func_v(tmpv);
ggml_set_name(tmpv, "tmpv");

struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
offload_func_v(Vcur);
ggml_set_name(Vcur, "Vcur");

struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
offload_func_kq(k);
ggml_set_name(k, "k");

struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
offload_func_v(v);
ggml_set_name(v, "v");

Expand All @@ -3547,7 +3564,7 @@ static struct ggml_cgraph * llm_build_refact(

struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_past + N, n_head_kv,
n_embd_head, n_kv, n_head_kv,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
Expand All @@ -3560,25 +3577,28 @@ static struct ggml_cgraph * llm_build_refact(
ggml_set_name(KQ, "KQ");

// KQ_scaled = KQ / sqrt(n_embd_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
// KQ_scaled shape [n_kv, n_tokens, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled");

struct ggml_tensor * KQ_masked;
struct ggml_tensor * KQ_scaled_alibi;

KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8);
// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8);
ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);

struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked");

// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");

// split cached V into n_head heads
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_past + N, n_embd_head, n_head_kv,
n_kv, n_embd_head, n_head_kv,
ggml_element_size(kv_self.v)*n_ctx,
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
Expand All @@ -3593,7 +3613,7 @@ static struct ggml_cgraph * llm_build_refact(
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
// is there a better way?
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head));
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
#endif

Expand All @@ -3602,10 +3622,8 @@ static struct ggml_cgraph * llm_build_refact(
offload_func_v(KQV_merged);
ggml_set_name(KQV_merged, "KQV_merged");

// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
// cur = KQV_merged.contiguous().view(n_embd, n_tokens)
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous");

Expand Down Expand Up @@ -4338,7 +4356,7 @@ static struct ggml_cgraph * llama_build_graph(
} break;
case LLM_ARCH_REFACT:
{
result = llm_build_refact(lctx, tokens, embd, n_tokens, n_past);
result = llm_build_refact(lctx, batch);
} break;
default:
GGML_ASSERT(false);
Expand Down

0 comments on commit af19099

Please sign in to comment.