First try at integration tests:
1. AttributeError: 'GriffinCausalLMOutput' object has no attribute 'attentions'.
2. `cache_position` not passed
botev committed Apr 3, 2024
1 parent b8e4de4 commit 6085aa5
Showing 3 changed files with 127 additions and 406 deletions.
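For context on failure (1): the generic Transformers tests typically read `outputs.attentions` off the model output, and a `ModelOutput`-style dataclass only exposes the fields it declares, so an attention-free architecture whose output class omits that field raises AttributeError. A minimal, self-contained sketch of the mechanism; the class below is illustrative and is not the actual GriffinCausalLMOutput.

from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class ToyCausalLMOutput:
    # Illustrative stand-in: it declares logits and hidden_states but no
    # `attentions` field, so plain `out.attentions` raises AttributeError,
    # which is the kind of error an attention-reading test trips over.
    logits: torch.Tensor
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None


out = ToyCausalLMOutput(logits=torch.zeros(1, 1, 8))
print(getattr(out, "attentions", None))  # None; direct attribute access would raise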
@@ -125,7 +125,6 @@ def __init__(
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_heads = num_heads
- self.num_attention_heads = num_heads
self.head_dim = self.hidden_size // self.num_heads
self.lru_width = lru_width if lru_width is not None else hidden_size
self.embeddings_scale_by_sqrt_dim = embeddings_scale_by_sqrt_dim
@@ -401,6 +401,7 @@ def forward(
queries = _apply_rope(queries, segment_pos)
keys = _apply_rope(keys, segment_pos)

+ cache = getattr(self, "cache", cache)
if cache is not None:
assert t == 1, f"When cache is provided only `t=1` is supported, not {t=}"

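The added `cache = getattr(self, "cache", cache)` line prefers a cache stashed on the module itself when that attribute exists, and only falls back to the argument otherwise. A hedged toy sketch of the pattern; the class and attribute names below are illustrative, not the repository's API.

from typing import Optional

import torch
from torch import nn


class ToyRecurrentBlock(nn.Module):
    def forward(self, x: torch.Tensor, cache: Optional[dict] = None) -> torch.Tensor:
        # If a `cache` attribute has been attached to the module (e.g. by a
        # generation loop), it takes precedence; otherwise use the argument.
        cache = getattr(self, "cache", cache)
        if cache is not None:
            assert x.shape[1] == 1, "with a cache, decode one token at a time"
        return x


block = ToyRecurrentBlock()
block(torch.zeros(1, 4, 8))        # no cache anywhere: full sequences are fine
block.cache = {"state": None}      # something like a generation loop attaching state
block(torch.zeros(1, 1, 8))        # now only single-step inputs pass the assert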
@@ -1200,7 +1201,7 @@ def __call__(
"""

bs, l, _ = x.shape
- assert segment_pos.shape == (bs, l)
+ assert segment_pos.shape == (bs, l), segment_pos.shape
reset = segment_pos == 0

# Gates for x and a.
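The tightened assert now reports the offending shape. `segment_pos` is expected to be `(batch, length)` positions that restart at 0 at each segment boundary, so `segment_pos == 0` marks where the recurrent state should be reset. A small illustrative example with made-up values:

import torch

segment_pos = torch.tensor([[0, 1, 2, 0, 1],    # two packed segments in one row
                            [0, 1, 2, 3, 4]])   # a single segment
bs, l = 2, 5
assert segment_pos.shape == (bs, l), segment_pos.shape
reset = segment_pos == 0
print(reset)
# tensor([[ True, False, False,  True, False],
#         [ True, False, False, False, False]])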
@@ -1838,7 +1839,9 @@ def forward(
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ **kwargs,
) -> Union[Tuple, GriffinOutput]:
+ print(kwargs)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
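The `**kwargs` catch-all (together with the temporary `print(kwargs)` debug line) lets extra generation-time arguments such as `cache_position` reach `forward` without raising TypeError, while `output_hidden_states` keeps the usual per-call-overrides-config pattern. A hedged sketch of both patterns with toy names:

from dataclasses import dataclass

import torch


@dataclass
class ToyConfig:
    output_hidden_states: bool = False


class ToyModel:
    def __init__(self, config: ToyConfig):
        self.config = config

    def forward(self, input_ids, output_hidden_states=None, **kwargs):
        # Per-call argument wins; otherwise fall back to the config default.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Unexpected arguments (e.g. cache_position) land here instead of erroring.
        return output_hidden_states, sorted(kwargs)


model = ToyModel(ToyConfig())
print(model.forward(torch.zeros(1, 3, dtype=torch.long), cache_position=torch.arange(3)))
# (False, ['cache_position'])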
@@ -2001,7 +2004,9 @@ def __init__(self, config: RecurrentGemmaConfig):
super().__init__(config)
self.model = RecurrentGemmaModel(config)
self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(config.width, config.vocab_size, bias=False)
+ self.lm_head = nn.Linear(
+     config.hidden_size, config.vocab_size, bias=False
+ )

# Initialize weights and apply final processing
self.post_init()
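The change swaps `config.width` for `config.hidden_size` as the in_features of the LM head: the head maps final hidden states of shape `[batch, seq, hidden_size]` to vocabulary logits, so its input dimension must match the model's hidden size. A quick shape check with arbitrary toy sizes:

import torch
from torch import nn

hidden_size, vocab_size = 64, 1000   # arbitrary toy sizes, not RecurrentGemma's
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

hidden_states = torch.randn(1, 7, hidden_size)
logits = lm_head(hidden_states)
print(logits.shape)  # torch.Size([1, 7, 1000])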