
Commit

Merge pull request #77 from s4rduk4r/offloading
Offloading to CPU and disk
casper-hansen authored Sep 27, 2023
2 parents 8793a9f + 841a231 commit f220ccf
Showing 3 changed files with 21 additions and 8 deletions.
8 changes: 5 additions & 3 deletions awq/models/auto.py
@@ -41,11 +41,13 @@ def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False,
     @classmethod
     def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                        trust_remote_code=True, fuse_layers=True,
-                       batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
+                       batch_size=1, safetensors=False,
+                       max_memory=None, offload_folder=None) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
             quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
-            fuse_layers=fuse_layers, safetensors=safetensors
-        )
+            fuse_layers=fuse_layers, safetensors=safetensors,
+            max_memory=max_memory, offload_folder=offload_folder
+        )
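With the new keyword arguments in place, a caller can cap per-device memory and point any spill-over at a folder directly from the Auto class. A minimal usage sketch (the checkpoint path and file name are placeholders; the max_memory dict follows accelerate's device-to-size convention, and anything exceeding both caps is offloaded to disk):

from awq import AutoAWQForCausalLM

# Hypothetical quantized checkpoint; adjust the path, file name and sizes.
model = AutoAWQForCausalLM.from_quantized(
    "path/to/awq-quantized-model",
    quant_filename="awq_model_w4_g128.safetensors",
    safetensors=True,
    max_memory={0: "8GiB", "cpu": "30GiB"},   # cap for GPU 0, then for CPU
    offload_folder="/tmp/awq_offload",        # overflow weights land here as files
)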
19 changes: 15 additions & 4 deletions awq/models/base.py
@@ -133,7 +133,8 @@ def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = tor
     def from_quantized(self, model_path, model_type, model_filename='',
                        max_new_tokens=None, torch_dtype=torch.float16,
                        trust_remote_code=True, safetensors=False, is_quantized=True,
-                       fuse_layers=False, version='GEMM'):
+                       fuse_layers=False, version='GEMM',
+                       max_memory=None, offload_folder=None):
         # [STEP 1-2] Load weights path and configs
         model_weights_path, config, quant_config = self._load_config(
             self, model_path, model_filename, safetensors, version,
@@ -153,22 +154,32 @@ def from_quantized(self, model_path, model_type, model_filename='',
         device_map = infer_auto_device_map(
             model,
             no_split_module_classes=[self.layer_type],
+            max_memory=max_memory,
             dtype=torch_dtype
         )
 
         # Load checkpoint
         load_checkpoint_in_model(
             model,
             checkpoint=model_weights_path,
-            device_map=device_map
+            device_map=device_map,
+            offload_folder=offload_folder,
+            dtype=torch_dtype
         )
 
-        # Dispath to devices
-        model = simple_dispatch_model(model, device_map)
 
         if fuse_layers:
             self.fuse_layers(model, quant_config)
 
+        # Offloading dispatch
+        from accelerate import dispatch_model
+        model = dispatch_model(
+            model,
+            device_map=device_map,
+            offload_dir=offload_folder
+        )
+
+
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
 
     def _load_config(self, model_path, model_filename, safetensors=False,
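End to end, the base.py changes hand offloading to accelerate: infer_auto_device_map respects the caller's max_memory, load_checkpoint_in_model writes overflow weights into offload_folder, and dispatch_model installs hooks that page offloaded weights back in during the forward pass. A standalone sketch of that flow with a toy model (layer sizes, file names and the tiny memory caps are illustrative only, and a CUDA device 0 is assumed):

import torch
from accelerate import (init_empty_weights, infer_auto_device_map,
                        load_checkpoint_in_model, dispatch_model)

# Toy stand-in for the quantized causal LM.
def make_model():
    return torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])

torch.save(make_model().state_dict(), "toy_checkpoint.bin")  # pretend AWQ checkpoint

with init_empty_weights():          # build the skeleton without allocating weights
    model = make_model()

# Tiny caps so some layers land on GPU 0, some on CPU, and the rest on disk.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "6MB", "cpu": "6MB"},
    no_split_module_classes=["Linear"],   # base.py passes [self.layer_type] here
    dtype=torch.float16,
)

# Weights mapped to "disk" are written into offload_folder instead of RAM.
load_checkpoint_in_model(
    model,
    checkpoint="toy_checkpoint.bin",
    device_map=device_map,
    offload_folder="./awq_offload",
    dtype=torch.float16,
)

# dispatch_model adds hooks that stream CPU/disk weights in at forward time.
model = dispatch_model(model, device_map=device_map, offload_dir="./awq_offload")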
2 changes: 1 addition & 1 deletion awq/modules/fused/attn.py
@@ -36,7 +36,7 @@ def apply_rotary_emb(
     xk_ = torch.view_as_complex(
         xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
     )
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
     xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
     return xq_out.type_as(xq), xk_out.type_as(xk)
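The one-line attn.py change fits the same offloading theme: once layers can sit on different devices, the precomputed freqs_cis table may not live where the activations do. A tiny illustrative sketch of the mismatch the added .to(xq_.device) avoids (shapes are made up, and a CUDA device is assumed for the error to actually trigger):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
xq_ = torch.randn(1, 8, 1, 4, dtype=torch.complex64, device=device)  # activations
freqs_cis = torch.randn(1, 8, 1, 4, dtype=torch.complex64)           # table left on CPU

# On a CUDA device, "xq_ * freqs_cis" raises a device-mismatch RuntimeError;
# moving the table to the activations' device first keeps the multiply valid.
xq_out = xq_ * freqs_cis.to(xq_.device)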
