Commit 175ae09

Merge branch 'hiyouga:main' into main

Truecodeman authored Oct 24, 2024
2 parents b4e7445 + b4c7dd3

Showing 5 changed files with 7 additions and 5 deletions.
Binary file modified assets/wechat.jpg
Binary file modified assets/wechat_npu.jpg
5 changes: 3 additions & 2 deletions src/llamafactory/extras/constants.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 from collections import OrderedDict, defaultdict
 from enum import Enum
 from typing import Dict, Optional
@@ -47,7 +48,7 @@

 IGNORE_INDEX = -100

-IMAGE_PLACEHOLDER = "<image>"
+IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")

 LAYERNORM_NAMES = {"norm", "ln"}

@@ -95,7 +96,7 @@

 SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

-VIDEO_PLACEHOLDER = "<video>"
+VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")

 V_HEAD_WEIGHTS_NAME = "value_head.bin"
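
Added note: a minimal sketch (not from the repository) of how the new environment overrides behave, assuming the package is importable as llamafactory. Because os.environ.get() runs once at import time, the variables must be set before the constants module is first loaded; the override values below are made up for illustration.

    import os

    # Set the overrides before llamafactory.extras.constants is imported;
    # the defaults "<image>" / "<video>" apply when the variables are unset.
    os.environ["IMAGE_PLACEHOLDER"] = "<img>"
    os.environ["VIDEO_PLACEHOLDER"] = "<vid>"

    from llamafactory.extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER

    print(IMAGE_PLACEHOLDER)  # "<img>" instead of the default "<image>"
    print(VIDEO_PLACEHOLDER)  # "<vid>" instead of the default "<video>"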
3 changes: 2 additions & 1 deletion src/llamafactory/model/loader.py
@@ -153,8 +153,9 @@ def load_model(
                 load_class = AutoModelForVision2Seq
             else:
                 load_class = AutoModelForCausalLM
+
             if model_args.train_from_scratch:
-                model = load_class.from_config(config)
+                model = load_class.from_config(config, trust_remote_code=True)
             else:
                 model = load_class.from_pretrained(**init_kwargs)

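
Added note: a self-contained sketch of the call pattern the loader now uses when model_args.train_from_scratch is set, written against the public transformers Auto classes rather than the project's own code. The gpt2 config is used only so the snippet runs; trust_remote_code=True matters for Hub repositories that ship custom modeling code and is harmless for built-in architectures.

    from transformers import AutoConfig, AutoModelForCausalLM

    # Build a model from its config alone: weights are randomly initialized,
    # which is what "train from scratch" relies on. trust_remote_code=True
    # additionally allows configs whose modeling code lives on the Hub.
    config = AutoConfig.from_pretrained("gpt2")  # stand-in config for the sketch
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    print(type(model).__name__)                        # GPT2LMHeadModel
    print(sum(p.numel() for p in model.parameters()))  # fresh, untrained parameters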
4 changes: 2 additions & 2 deletions src/llamafactory/train/test_utils.py
@@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
     assert set(state_dict_a.keys()) == set(state_dict_b.keys())
     for name in state_dict_a.keys():
         if any(key in name for key in diff_keys):
-            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False
+            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is False
         else:
-            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True
+            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-2, atol=1e-3) is True


 def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
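
Added note: torch.allclose treats two elements as equal when |a - b| <= atol + rtol * |b|, so moving from rtol=1e-3, atol=1e-4 to rtol=1e-2, atol=1e-3 widens the accepted band by roughly an order of magnitude. A quick illustration with made-up tensors:

    import torch

    a = torch.tensor([1.000, 2.000])
    b = torch.tensor([1.004, 2.004])  # each element differs by 4e-3

    print(torch.allclose(a, b, rtol=1e-3, atol=1e-4))  # False under the old, tighter bounds
    print(torch.allclose(a, b, rtol=1e-2, atol=1e-3))  # True under the new, looser bounds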
