Fix Conversion Issues + add support for remote upload. #694

Merged: soldni merged 7 commits into main from soldni/ddp-conversion-fix on Aug 21, 2024

Conversation

@soldni (Member) commented Aug 7, 2024

No description provided.

@soldni marked this pull request as draft August 7, 2024 06:24
@soldni requested a review from 2015aroras August 7, 2024 23:47
@soldni marked this pull request as ready for review August 7, 2024 23:47
@2015aroras (Collaborator) left a comment:

Thank you for improving this script!

# this takes care of the case where the model was saved with a different prefix,
# typically due to unsharding.
common_prefix = longest_common_prefix(state_dict.keys())
state_dict = {key.replace(common_prefix, "transformer."): val for key, val in state_dict.items()}
@2015aroras (Collaborator):

nit: You could merge this and the next line with

new_state_dict = {key.replace(common_prefix, f"{OLMoForCausalLM.base_model_prefix}.transformer."): val for key, val in state_dict.items()}

@soldni (Member, Author):

nice!
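For reference, the whole prefix-handling step boils down to something like the sketch below; the longest_common_prefix body and the hf_olmo import path are assumptions, not necessarily what the script ships.

import os
from typing import Dict, Iterable

import torch
from hf_olmo import OLMoForCausalLM  # HF wrapper whose base_model_prefix gets prepended

def longest_common_prefix(keys: Iterable[str]) -> str:
    # Longest shared leading string across all state dict keys, e.g. a stray
    # prefix left over from unsharding.
    return os.path.commonprefix(list(keys))

def remap_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Swap whatever prefix the checkpoint was saved with for the one the HF
    # model class expects, in a single dict comprehension as suggested above.
    common_prefix = longest_common_prefix(state_dict.keys())
    return {
        key.replace(common_prefix, f"{OLMoForCausalLM.base_model_prefix}.transformer."): val
        for key, val in state_dict.items()
    }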

    convert_checkpoint(local_model_path)
    return local_model_path

    try:
        Tokenizer.from_pretrained(tokenizer_name_or_path)
@2015aroras (Collaborator):

Maybe Tokenizer.from_checkpoint(checkpoint_dir) would be better. It can handle local tokenizer JSONs (like what we save in our repo) as well as HF pretrained checkpoints. Tbh I'd be surprised if this even works when the config says something like tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json

@soldni (Member, Author):

I wanna be able to run this script outside the OLMo directory (that is, right after installing ai2-olmo via pip), and in that case I wouldn't have the JSON vocabulary files on path (not particularly fond of those being in the olmo repo anyway).

gonna re-jiggle this a bit to support both.

@2015aroras (Collaborator) commented Aug 20, 2024:

I changed Tokenizer.from_checkpoint to work with pip-installed ai2-olmo (in particular, I handled the scenario where the tokenizer field is something like tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json). If it doesn't work in that scenario, please let me know so that I can fix it at some point.

I actually forgot to put my change in Tokenizer.from_file, so I'd be surprised if pip-installed olmo works with your current code.
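The "support both" behavior discussed in this thread could look roughly like the following; the load_tokenizer wrapper is hypothetical, and it assumes both class methods raise when they cannot resolve their argument.

from olmo.tokenizer import Tokenizer

def load_tokenizer(tokenizer_name_or_path: str, checkpoint_dir: str) -> Tokenizer:
    try:
        # Works outside the OLMo repo, e.g. after `pip install ai2-olmo`,
        # when the identifier is an HF pretrained tokenizer name.
        return Tokenizer.from_pretrained(tokenizer_name_or_path)
    except Exception:
        # Fall back to the tokenizer referenced by the checkpoint's config,
        # which may be a repo-local JSON such as
        # tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json.
        return Tokenizer.from_checkpoint(checkpoint_dir)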

def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: str | None = None):

# Create S3 client
s3_client = boto3.client("s3")
@2015aroras (Collaborator):

Something like the following may be more convenient for setups that use multiple profiles.

Suggested change
-    s3_client = boto3.client("s3")
+    s3_client = util._get_s3_client("s3")
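For illustration, the multi-profile concern amounts to building the client from a boto3 session tied to the selected profile rather than calling boto3.client directly; the helper below is a hypothetical stand-in, not the repo's util._get_s3_client.

import os
import boto3

def make_s3_client():
    # Respect an explicitly chosen AWS profile (e.g. via AWS_PROFILE) and fall
    # back to the default credential chain otherwise.
    profile = os.environ.get("AWS_PROFILE")
    session = boto3.Session(profile_name=profile) if profile else boto3.Session()
    return session.client("s3")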


# Initialize the progress bar
with tqdm(total=len(files_to_download), desc="Downloading files") as pbar:
    for s3_key in files_to_download:
@2015aroras (Collaborator):

Does something like the following not work? It would get rid of the need for pbar.update(1), and remove one level of nesting.

for s3_key in tqdm(files_to_download, desc="Downloading files"):

print(f"Downloading checkpoint to {temp_dir}...")
download_s3_directory(
bucket_name=parsed_dir.netloc,
prefix=parsed_dir.path[1:],
@2015aroras (Collaborator):

Why do you need to remove the first letter? If it's to get rid of /, then lstrip("/") might indicate your intent more clearly

@soldni (Member, Author):

chatgpt code 🤪
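For what it's worth, the slicing works because urlparse keeps the leading "/" on the path component of an s3:// URL; lstrip("/") removes the same character while naming the intent (the URL below is just an example).

from urllib.parse import urlparse

parsed_dir = urlparse("s3://my-bucket/checkpoints/step1000-unsharded")
print(parsed_dir.netloc)            # my-bucket
print(parsed_dir.path)              # /checkpoints/step1000-unsharded
print(parsed_dir.path.lstrip("/"))  # checkpoints/step1000-unsharded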


print(f"Unsharding {checkpoint_dir}...")
train_config = TrainConfig.load(os.path.join(checkpoint_dir, "config.yaml"))
checkpointer = OlmoCoreCheckpointer(train_config)
@2015aroras (Collaborator):

Suggested change
-    checkpointer = OlmoCoreCheckpointer(train_config)
+    checkpointer = build_sharded_checkpointer(train_config)

This would allow other types of sharded checkpoints to work here.
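Conceptually, the factory picks the checkpointer implementation that matches how the checkpoint was sharded. The sketch below is hypothetical (the config field name and value are assumed), not the actual build_sharded_checkpointer.

from olmo.checkpoint import OlmoCoreCheckpointer
from olmo.config import TrainConfig

def build_sharded_checkpointer_sketch(train_config: TrainConfig):
    # Dispatch on the sharding strategy recorded in the training config
    # instead of hard-coding a single checkpointer class.
    kind = getattr(train_config, "sharded_checkpointer", "olmo_core")  # field name assumed
    if kind == "olmo_core":
        return OlmoCoreCheckpointer(train_config)
    raise NotImplementedError(f"no checkpointer registered for {kind}")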



@contextmanager
def upload_local_checkpoint(local_checkpoint_dir: str, destination_dir: str) -> Generator[None, None, None]:
@2015aroras (Collaborator):

Is there any reason this needs to be a context manager? You could just call it once all the converting is done?

@soldni (Member, Author):

you are right; fixed. It made sense before I removed some functionality, but it's useless now.
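After this change, the upload step can be an ordinary function called once conversion is done. A minimal sketch, assuming destination_dir is an s3:// URL and every file under local_checkpoint_dir should be mirrored beneath it:

import os
from urllib.parse import urlparse

import boto3

def upload_local_checkpoint(local_checkpoint_dir: str, destination_dir: str) -> None:
    parsed = urlparse(destination_dir)
    bucket, prefix = parsed.netloc, parsed.path.lstrip("/")
    s3_client = boto3.client("s3")
    for root, _, files in os.walk(local_checkpoint_dir):
        for name in files:
            local_path = os.path.join(root, name)
            # Preserve the directory layout relative to the checkpoint root.
            rel_path = os.path.relpath(local_path, local_checkpoint_dir)
            s3_client.upload_file(local_path, bucket, f"{prefix}/{rel_path}")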



@contextmanager
def make_local_checkpoint(checkpoint_dir: str) -> Generator[str, None, None]:
@2015aroras (Collaborator):

Is there any reason this needs to be a context manager? You could just call it at the start?

@soldni (Member, Author):

you are right; fixed.
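Similarly, the resulting plain function only needs to hand back a local path, downloading first when the checkpoint lives on S3. A minimal sketch built on the download_s3_directory helper quoted earlier (the temp-dir handling is an assumption, not necessarily the PR's exact code):

import tempfile
from urllib.parse import urlparse

def make_local_checkpoint(checkpoint_dir: str) -> str:
    # Local checkpoints are converted in place; remote ones are mirrored first.
    if not checkpoint_dir.startswith("s3://"):
        return checkpoint_dir
    parsed_dir = urlparse(checkpoint_dir)
    temp_dir = tempfile.mkdtemp()
    print(f"Downloading checkpoint to {temp_dir}...")
    download_s3_directory(
        bucket_name=parsed_dir.netloc,
        prefix=parsed_dir.path.lstrip("/"),
        local_dir=temp_dir,
    )
    return temp_dir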

@soldni merged commit 4575d40 into main on Aug 21, 2024 (11 of 12 checks passed)
@soldni deleted the soldni/ddp-conversion-fix branch August 21, 2024 06:21