Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add refact model #3329

Merged
merged 5 commits into from
Oct 4, 2023
Merged

add refact model #3329

merged 5 commits into from
Oct 4, 2023

Conversation

ds5t5
Copy link
Contributor

@ds5t5 ds5t5 commented Sep 25, 2023

example command (greedy) to test against huggingface.

python3 convert-refact-hf-to-gguf.py ./Refact-1_6B-fim 1

./main -m ./Refact-1_6B-fim/ggml-model-f16.gguf -n 300 -p "write a function to multiple two integers in python"  --temp 1.0 --top-p 1.0 --top-k 1 --repeat_penalty 1.0

resolve: #3061

Copy link
Collaborator

@Green-Sky Green-Sky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did not test.

on a side note: we end up with more and more duplicated code when we add more and more models. at least for the hf model loading in python.

convert-refact-hf-to-gguf.py Outdated Show resolved Hide resolved
@teleprint-me
Copy link
Contributor

teleprint-me commented Sep 25, 2023

did not test.

on a side note: we end up with more and more duplicated code when we add more and more models. at least for the hf model loading in python.

I tested it out and it's working as expected.

Any idea on how you'd like it refactored to reduce the duplicates down to a single convert.py?

If not, I could probably mock up some ideas once I have some resources and time.

As long as I have a bit of an idea (e.g. a track) of how to go about it or what might be expected, it shouldn't be much of a problem. Merging multiple interfaces into a single coherent interface in python is kind of my thing.

@ds5t5
Copy link
Contributor Author

ds5t5 commented Sep 25, 2023

@teleprint-me it is probably worthy changing it after the special token issue is resolved. #2820. I think we had a few bugs in the current (all) converter when there is no added_token.json but special_tokens_map.json. wdyt?

@teleprint-me
Copy link
Contributor

teleprint-me commented Sep 26, 2023

@ds5t5

I think that attempting to support a variety of variable special token types is a challenging task. Probably not impossible, but also probably not worth pursuing.

A standard interface like ChatML would be better suited rather than attempting to adapt to a variety of variable special tokens that would be model specific.

There really isn't a "holy grail" solution to this. It's up to the dataset creators as well as the fine-tuners.

It's easier to create a generally abstracted interface that can accommodate and adapt to the variable structures that creators and developers would want to implement.

I mentioned something similar to this on a issue in llama-cpp-python because the bos and eos tokens are hard-coded into the method for creating chat completions.

This isn't an issue specific to llama.cpp. It's an issue that requires a generally agreed upon specification for everyone to agree to operate under. I don't see that happening for awhile though, so we'll see.

Creating a general interface for handling variable conversions would face a similar issue, would be manageable and generally worth pursing if only to reduce the number of front-end scripts and code duplication.

You could create a factory that would abstract it and then create the instance for converting the tensors and have a single front-end CLI interface as a result. This would be modular, approachable, maintainable, as well as extensible.

@ds5t5
Copy link
Contributor Author

ds5t5 commented Sep 26, 2023

@teleprint-me i totally agree with you point that we should probably modularize the converter for HuggingFace one. do you think we could do it in another PR? i guess merging HuggingFace falcon, baichuan, refact and starcoder would be easier to start since convert.py also includes original llama pt version.

@Green-Sky
Copy link
Collaborator

I am getting a bunch of key xxx not in tokenizer vocabulary. padding with arbitrary token.

Details
./convert-refact-hf-to-gguf.py models/Refact-1_6B-fim/
gguf: loading model Refact-1_6B-fim
gguf: get model metadata
gguf: get tokenizer metadata
gguf: get gpt2 tokenizer vocab
Key 49152 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49153 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49154 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49155 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49156 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49157 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49158 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49159 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49160 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49161 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49162 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49163 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49164 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49165 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49166 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49167 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49168 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49169 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49170 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49171 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49172 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49173 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49174 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49175 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49176 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49177 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49178 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49179 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49180 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49181 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49182 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49183 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49184 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49185 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49186 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49187 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49188 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49189 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49190 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49191 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49192 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49193 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49194 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49195 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49196 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49197 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49198 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49199 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49200 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49201 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49202 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49203 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49204 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49205 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49206 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49207 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49208 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49209 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49210 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49211 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49212 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49213 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49214 not in tokenizer vocabulary. Padding with an arbitrary token.
Key 49215 not in tokenizer vocabulary. Padding with an arbitrary token.
gguf: Adding 48891 merge(s).
gguf: Setting special token type bos to 0
gguf: Setting special token type eos to 0
gguf: Setting special token type unk to 0
gguf: get tensor metadata
gguf: loading model part 'pytorch_model.bin'
......

@Green-Sky
Copy link
Collaborator

benchmarks:

cpu only with openblas:

$ llama-bench -m models/Refact-1_6B-fim/ggml-model-Q8_0.gguf -p 256 -p 512 -p 1024 -p 2048 -n 128 -n 256 -n 512 -n 1024 -n 2048

model size params backend threads test t/s
Refact 1B mostly Q8_0 1.57 GiB 1.59 B BLAS 12 pp 256 45.62 ± 1.66
Refact 1B mostly Q8_0 1.57 GiB 1.59 B BLAS 12 pp 512 29.69 ± 0.65
Refact 1B mostly Q8_0 1.57 GiB 1.59 B BLAS 12 pp 1024 22.96 ± 0.07
Refact 1B mostly Q8_0 1.57 GiB 1.59 B BLAS 12 pp 2048 15.62 ± 0.15
Refact 1B mostly Q8_0 1.57 GiB 1.59 B BLAS 12 tg 128 25.86 ± 0.02
Refact 1B mostly Q8_0 1.57 GiB 1.59 B BLAS 12 tg 256 25.41 ± 0.04
Refact 1B mostly Q8_0 1.57 GiB 1.59 B BLAS 12 tg 512 24.56 ± 0.01
Refact 1B mostly Q8_0 1.57 GiB 1.59 B BLAS 12 tg 1024 23.03 ± 0.02
Refact 1B mostly Q8_0 1.57 GiB 1.59 B BLAS 12 tg 2048 19.94 ± 0.87

gpu only cuda:

$ llama-bench -m models/Refact-1_6B-fim/ggml-model-Q8_0.gguf -p 256 -p 512 -p 1024 -p 2048 -n 128 -n 256 -n 512 -n 1024 -n 2048
ggml_init_cublas: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 2070, compute capability 7.5

model size params backend ngl test t/s
Refact 1B mostly Q8_0 1.57 GiB 1.59 B CUDA 99 pp 256 1048.71 ± 18.89
Refact 1B mostly Q8_0 1.57 GiB 1.59 B CUDA 99 pp 512 657.43 ± 19.25
Refact 1B mostly Q8_0 1.57 GiB 1.59 B CUDA 99 pp 1024 576.09 ± 18.19
Refact 1B mostly Q8_0 1.57 GiB 1.59 B CUDA 99 pp 2048 463.34 ± 4.59
Refact 1B mostly Q8_0 1.57 GiB 1.59 B CUDA 99 tg 128 89.19 ± 2.42
Refact 1B mostly Q8_0 1.57 GiB 1.59 B CUDA 99 tg 256 92.51 ± 0.31
Refact 1B mostly Q8_0 1.57 GiB 1.59 B CUDA 99 tg 512 88.78 ± 1.15
Refact 1B mostly Q8_0 1.57 GiB 1.59 B CUDA 99 tg 1024 82.45 ± 0.61
Refact 1B mostly Q8_0 1.57 GiB 1.59 B CUDA 99 tg 2048 73.32 ± 0.48

@teleprint-me
Copy link
Contributor

@ds5t5

i totally agree with you point that we should probably modularize the converter for HuggingFace one.

This sounds like a good place to start. I usually prefer using a template to build off of.

do you think we could do it in another PR?

Sure, we can do it in another PR. I don't mind at all.

i guess merging HuggingFace falcon, baichuan, refact and starcoder would be easier to start since convert.py also includes original llama pt version.

I'm open to looking into this and creating a skeleton or writing up an outline to plan it out.

Let me know.

@ds5t5
Copy link
Contributor Author

ds5t5 commented Sep 26, 2023

@Green-Sky that adding is expected to match the vocab size. it is similar to falcon converter code. btw, the CI is not passing due to this one https://github.com/ggerganov/llama.cpp/actions/runs/6309006961/job/17136638816?pr=3329.

Error: Waiting for VM to become ready timed out after 120 seconds

The rebasing is not helpful. Does anyone know how we can fix it? thanks. cc @Green-Sky @ggerganov

@slaren
Copy link
Collaborator

slaren commented Sep 26, 2023

It is a known issue, you can safely ignore the freeBSD CI failures.

Copy link
Collaborator

@Green-Sky Green-Sky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me.
did some very basic testing, but no FIM. still waiting on #2934

Copy link
Owner

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to adapt to the new llama_batch API and also replace ggml_alibi with ggml_add as we did for ggml_diag_mask_inf in #3228

llama.cpp/llama.cpp

Lines 3250 to 3262 in 0e76a89

switch (model.type) {
case MODEL_7B:
KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
break;
case MODEL_13B:
// TODO: replace with ggml_add()
KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8);
ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
break;
default:
GGML_ASSERT(false);
}

@ds5t5
Copy link
Contributor Author

ds5t5 commented Sep 29, 2023

@ggerganov i am getting incorrect result for refact (alibi) when switching to use the new KQ_mask construction outside for loop and ggml_add. only ggml_diag_mask_inf is giving the correct result compared with huggingface. have we verified baichuan 13b for the correctness on this new change?

@ggerganov
Copy link
Owner

No, baichuan 13B is known to not work at the moment.

The KQ_mask that is currently constructed is used only to set the infinite elements in the attention in order to select which tokens we want to attend to. We need a new ALiBi tensor (similar to KQ_mask) that contains the correct values based on batch.pos[i] to be added to KQ_scaled. I haven't done it because I don't have a model to test with handy.

@ds5t5
Copy link
Contributor Author

ds5t5 commented Sep 29, 2023

@ggerganov i took my word back. it doesn't have a problem. i will push a new PR based on the latest branch

@ds5t5
Copy link
Contributor Author

ds5t5 commented Sep 29, 2023

@ggerganov @Green-Sky i have pushed the new commit to rebase to the latest one. it will pass the metal gpu, however it will fail on CPU only mode with this error from ggml_compute_forward_alibi_f32. it looks like

GGML_ASSERT: ggml.c:12913: ne1 + n_past == ne0

i added another fix to remove the assert since we pass 0 as n_past in ggml_alibi now.

@ggerganov
Copy link
Owner

Does this implementation produce correct results? I think the ggml_alibi will be incorrect, since n_past is now always 0.

@ds5t5
Copy link
Contributor Author

ds5t5 commented Sep 29, 2023

@ggerganov i am getting correct results. i followed the example code here.

KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8);
.
do you mean we should actually set n_past instead of changing it to 0? And i tested it by setting it with kv_head (old n_past). It looks like the result is identical and no difference on speed.

@goerch
Copy link
Collaborator

goerch commented Sep 30, 2023

I see Refact is using a GPT2-based tokenizer. Would you care to check the impact of #3252 on this conversion (although I didn't have time to consider special_tokens_map.json yet)? Thanks!

@ggerganov
Copy link
Owner

@ggerganov i am getting correct results. i followed the example code here.

Ah interesting. I just realize that ggml_alibi does not use n_past for anything.
So the implementation was even simpler than I imagined. In any case, I want to take a bit deeper look - I think we want to deprecate ggml_alibi and use ggml_add to replace it. Will look into merging this early this week.

Might look into merging #3252 before that

Copy link
Owner

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's resolve conflicts and merge

@ggerganov
Copy link
Owner

Hope I didn't break something with that last merge

@ggerganov ggerganov merged commit f8c90cd into ggerganov:master Oct 4, 2023
26 of 32 checks passed
ggerganov added a commit that referenced this pull request Oct 4, 2023
@martell
Copy link

martell commented Oct 4, 2023

Hi @ggerganov,
Can confirm that af19099 from an earlier commit in this PR works but 0d152b3 does not. Built the gguf using what was merged and used the same one with both commits
This is on an M2 Pro using Metal

llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type  f16:  226 tensors
error loading model: invalid character
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '../models/Refact-1_6B-fim/ggml-model-f16.gguf'
main: error: unable to load model

@ggerganov
Copy link
Owner

Thanks for reporting this - there are some ongoing tokenizer changes recently and things seem to be unstable. Similar issue was reported in #3484 - not sure if related

joelkuiper added a commit to vortext/llama.cpp that referenced this pull request Oct 5, 2023
…example

* 'master' of github.com:ggerganov/llama.cpp: (24 commits)
  convert : fix Baichuan2 models by using vocab size in config.json (ggerganov#3299)
  readme : add project status link
  ggml : fix build after ggerganov#3329
  llm : add Refact model (ggerganov#3329)
  sync : ggml (conv 1d + 2d updates, UB fixes) (ggerganov#3468)
  finetune : readme fix typo (ggerganov#3465)
  ggml : add RISC-V Vector Support for K-Quants and improved the existing intrinsics (ggerganov#3453)
  main : consistent prefix/suffix coloring (ggerganov#3425)
  llama : fix session saving/loading (ggerganov#3400)
  llama : expose model's rope_freq_scale in the API (ggerganov#3418)
  metal : alibi for arbitrary number of heads (ggerganov#3426)
  cmake : make LLAMA_NATIVE flag actually use the instructions supported by the processor (ggerganov#3273)
  Work on the BPE tokenizer (ggerganov#3252)
  convert : fix vocab size when not defined in hparams (ggerganov#3421)
  cmake : increase minimum version for add_link_options (ggerganov#3444)
  CLBlast: Add broadcast support for matrix multiplication (ggerganov#3402)
  gguf : add BERT, MPT, and GPT-J arch info (ggerganov#3408)
  gguf : general usability improvements (ggerganov#3409)
  cmake : make CUDA flags more similar to the Makefile (ggerganov#3420)
  finetune : fix ggerganov#3404 (ggerganov#3437)
  ...
@ggerganov
Copy link
Owner

ggerganov commented Oct 6, 2023

@martell I think you need to re-convert the model using the updated python script and it should work

@martell
Copy link

martell commented Oct 6, 2023

I should have shared a checksum in the original comment, I converted it at commit 0d152b3
The latest HEAD is giving me the same checksum and error.

llama.cpp

git log | head -1 
commit 1faaae8c2bdc4a21302e367e0754c3fe74a8113e

Refact-1_6B-fim

git log | head -1
commit acc9591f69aae4d950d58d372aa6c8b34543fd2c

converted using

python3 convert-refact-hf-to-gguf.py ../Refact-1_6B-fim 1
shasum -a 256 ../Refact-1_6B-fim/ggml-model-f16.gguf
73eb4b5a25d3c64fbfefbca332596b668bd22d5be66aa83d0496200e7ea5e59f ../Refact-1_6B-fim/ggml-model-f16.gguf
 ./main -m ../Refact-1_6B-fim/ggml-model-f16.gguf -n 300 -p "do something"  --temp 1.0 --top-p 1.0 --top-k 1 --repeat_penalty 1.0

...
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type  f16:  226 tensors
error loading model: invalid character
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '../Refact-1_6B-fim/ggml-model-f16.gguf'
main: error: unable to load model

yusiwen pushed a commit to yusiwen/llama.cpp that referenced this pull request Oct 7, 2023
* add refact model

* resolve comments

* rebase to the latest

* solve alibi cpu error

---------

Co-authored-by: Georgi Gerganov <[email protected]>
yusiwen pushed a commit to yusiwen/llama.cpp that referenced this pull request Oct 7, 2023
@ggerganov
Copy link
Owner

@martell Please test the branch in #3523

@martell
Copy link

martell commented Oct 8, 2023

@ggerganov Can confirm that it now runs with 42833bc

shasum -a 256 ../Refact-1_6B-fim/ggml-model-f16.gguf
c53008cce38590f602c0b04939c17da929968acae8ddf3672a2aff7082cf937e  ../Refact-1_6B-fim/ggml-model-f16.gguf

It seems to terminate early sometimes but I presume that is due to nans before alibi being discussed there.
I will have to read a lot more on the various internal naming of things to follow along more clearly.

./main -m ./Refact-1_6B-fim/ggml-model-f16.gguf -n 300 -p "write a function to multiple two integers in python"  --temp 1.0 --top-p 1.0 --top-k 1 --repeat_penalty 1.0
...
ggml_metal_init: hasUnifiedMemory              = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 10922.67 MB
ggml_metal_init: maxTransferRate               = built-in GPU
llama_new_context_with_model: compute buffer total size = 106.00 MB
llama_new_context_with_model: max tensor size =   192.25 MB
ggml_metal_add_buffer: allocated 'data            ' buffer, size =  3026.86 MB, ( 3027.36 / 10922.67)
ggml_metal_add_buffer: allocated 'kv              ' buffer, size =     6.00 MB, ( 3033.36 / 10922.67)
ggml_metal_add_buffer: allocated 'alloc           ' buffer, size =   100.14 MB, ( 3133.50 / 10922.67)

system_info: n_threads = 8 / 12 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | 
sampling: repeat_last_n = 64, repeat_penalty = 1.000000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 1, tfs_z = 1.000000, top_p = 1.000000, typical_p = 1.000000, temp = 1.000000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000
generate: n_ctx = 512, n_batch = 512, n_predict = 300, n_keep = 0


write a function to multiple two integers in python<|endoftext|> [end of text]

@ggerganov
Copy link
Owner

It seems to terminate early sometimes but I presume that is due to nans before alibi being discussed there.

This should be resolved in 42833bc
Can you double-check that you have built the commit make clean && make? It should no longer terminate early, so it would be unexpected if it happens

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
high priority Very important issue model Model specific
Projects
None yet
Development

Successfully merging this pull request may close these issues.

llama : add Refact support
7 participants