Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offloading to cpu and disk #77

Merged
merged 4 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions awq/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False,
@classmethod
def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
trust_remote_code=True, fuse_layers=True,
batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
batch_size=1, safetensors=False,
max_memory=None, offload_folder=None) -> BaseAWQForCausalLM:
os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
model_type = check_and_get_model_type(quant_path, trust_remote_code)

return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
fuse_layers=fuse_layers, safetensors=safetensors
)
fuse_layers=fuse_layers, safetensors=safetensors,
max_memory=max_memory, offload_folder=offload_folder
)
19 changes: 15 additions & 4 deletions awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = tor
def from_quantized(self, model_path, model_type, model_filename='',
max_new_tokens=None, torch_dtype=torch.float16,
trust_remote_code=True, safetensors=False, is_quantized=True,
fuse_layers=False, version='GEMM'):
fuse_layers=False, version='GEMM',
max_memory=None, offload_folder=None):
# [STEP 1-2] Load weights path and configs
model_weights_path, config, quant_config = self._load_config(
self, model_path, model_filename, safetensors, version,
Expand All @@ -153,22 +154,32 @@ def from_quantized(self, model_path, model_type, model_filename='',
device_map = infer_auto_device_map(
model,
no_split_module_classes=[self.layer_type],
max_memory=max_memory,
dtype=torch_dtype
)

# Load checkpoint
load_checkpoint_in_model(
model,
checkpoint=model_weights_path,
device_map=device_map
device_map=device_map,
offload_folder=offload_folder,
dtype=torch_dtype
)

# Dispath to devices
model = simple_dispatch_model(model, device_map)

if fuse_layers:
self.fuse_layers(model, quant_config)

# Offloading dispatch
from accelerate import dispatch_model
model = dispatch_model(
model,
device_map=device_map,
offload_dir=offload_folder
)


return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

def _load_config(self, model_path, model_filename, safetensors=False,
Expand Down
2 changes: 1 addition & 1 deletion awq/modules/fused/attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def apply_rotary_emb(
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
Expand Down