4 changes: 2 additions & 2 deletions lm_engine/hf_models/model_conversion/__init__.py
@@ -25,7 +25,7 @@


def import_from_huggingface(
pretrained_model_name_or_path: str, save_path: str | None = None
pretrained_model_name_or_path: str, save_path: str | None = None, **kwargs
) -> tuple[GPTBaseConfig, GenerationConfig, AutoTokenizer, dict]:
original_config, tokenizer, downloaded_model_path = download_repo(pretrained_model_name_or_path)
model_type = original_config.model_type
@@ -35,7 +35,7 @@ def import_from_huggingface(

config_import_function, state_dict_import_function = _MODEL_IMPORT_FUNCTIONS[model_type]

config = config_import_function(original_config)
config = config_import_function(original_config, **kwargs)

state_dict = state_dict_import_function(
config=config, safetensors_weights_manager=SafeTensorsWeightsManager(downloaded_model_path)
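A minimal usage sketch of the extended entry point; the repo id below is illustrative, and any extra keyword is simply forwarded via **kwargs to the model-specific config importer:

import torch

from lm_engine.hf_models.model_conversion import import_from_huggingface  # module path per the file above

config, generation_config, tokenizer, state_dict = import_from_huggingface(
    "org/some-llama-checkpoint",   # hypothetical Hugging Face repo id
    use_interleaved_weights=True,  # forwarded to the config importer, e.g. _import_llama_config
)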
6 changes: 5 additions & 1 deletion lm_engine/hf_models/model_conversion/granite.py
@@ -7,9 +7,10 @@
from ..models import GPTBaseConfig


def _import_granite_config(original_config: GraniteConfig) -> GPTBaseConfig:
def _import_granite_config(original_config: GraniteConfig, **kwargs) -> GPTBaseConfig:
assert original_config.hidden_act == "silu"
assert original_config.mlp_bias == original_config.attention_bias
use_interleaved_weights = kwargs.pop("use_interleaved_weights", False)

config = GPTBaseConfig(
vocab_size=original_config.vocab_size,
@@ -47,11 +48,14 @@ def _import_granite_config(original_config: GraniteConfig) -> GPTBaseConfig:
"add_bias": original_config.mlp_bias,
"activation_function": "swiglu",
"intermediate_size": original_config.intermediate_size,
"use_interleaved_weights": use_interleaved_weights,
}
for _ in range(original_config.num_hidden_layers)
],
)

assert len(kwargs) == 0

return config


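A hedged sketch of the kwargs contract adopted here (assuming a default GraniteConfig passes the importer's asserts, and the import paths match the files above): recognised keys are popped, and the trailing `assert len(kwargs) == 0` rejects anything left over, so misspelled options fail loudly instead of being silently ignored:

from transformers import GraniteConfig

from lm_engine.hf_models.model_conversion.granite import _import_granite_config

original_config = GraniteConfig()  # default config; hidden_act="silu", matching bias flags

config = _import_granite_config(original_config, use_interleaved_weights=True)  # recognised key: popped

try:
    _import_granite_config(original_config, use_interleved_weights=True)  # misspelled key is left in kwargs
except AssertionError:
    print("leftover kwarg rejected by `assert len(kwargs) == 0`")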
250 changes: 136 additions & 114 deletions lm_engine/hf_models/model_conversion/granitemoehybrid.py

Large diffs are not rendered by default.

152 changes: 83 additions & 69 deletions lm_engine/hf_models/model_conversion/llama.py
@@ -14,9 +14,10 @@
from ..models import GPTBaseConfig


def _import_llama_config(original_config: LlamaConfig) -> GPTBaseConfig:
def _import_llama_config(original_config: LlamaConfig, **kwargs) -> GPTBaseConfig:
assert original_config.hidden_act == "silu"
assert original_config.mlp_bias == original_config.attention_bias
use_interleaved_weights = kwargs.pop("use_interleaved_weights", False)

config = GPTBaseConfig(
vocab_size=original_config.vocab_size,
@@ -50,11 +51,14 @@ def _import_llama_config(original_config: LlamaConfig) -> GPTBaseConfig:
"add_bias": original_config.mlp_bias,
"activation_function": "swiglu",
"intermediate_size": original_config.intermediate_size,
"use_interleaved_weights": use_interleaved_weights,
}
for _ in range(original_config.num_hidden_layers)
],
)

assert len(kwargs) == 0

return config


@@ -72,59 +76,62 @@ def _import_llama_state_dict(config: GPTBaseConfig, safetensors_weights_manager:
state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight")

for layer_idx in range(config.num_layers):
state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = safetensors_weights_manager.get_tensor(
f"model.layers.{layer_idx}.input_layernorm.weight"
import_prefix = f"transformer.h.{layer_idx}."
export_prefix = f"model.layers.{layer_idx}."

use_interleaved_weights = config.mlp_blocks[layer_idx].use_interleaved_weights

state_dict[f"{import_prefix}ln_1.weight"] = safetensors_weights_manager.get_tensor(
f"{export_prefix}input_layernorm.weight"
)
state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = safetensors_weights_manager.get_tensor(
f"model.layers.{layer_idx}.post_attention_layernorm.weight"
state_dict[f"{import_prefix}ln_2.weight"] = safetensors_weights_manager.get_tensor(
f"{export_prefix}post_attention_layernorm.weight"
)

state_dict[f"transformer.h.{layer_idx}.mlp_block.c_fc.weight"] = interleave_up_gate_tensor_for_mlp(
safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.up_proj.weight"),
safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.gate_proj.weight"),
state_dict[f"{import_prefix}mlp_block.c_fc.weight"] = interleave_up_gate_tensor_for_mlp(
safetensors_weights_manager.get_tensor(f"{export_prefix}mlp.up_proj.weight"),
safetensors_weights_manager.get_tensor(f"{export_prefix}mlp.gate_proj.weight"),
is_interleaved=use_interleaved_weights,
)
if f"model.layers.{layer_idx}.mlp.up_proj.bias" in safetensors_weights_manager:
state_dict[f"transformer.h.{layer_idx}.mlp_block.c_fc.bias"] = interleave_up_gate_tensor_for_mlp(
safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.up_proj.bias"),
safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.gate_proj.bias"),
if f"{export_prefix}mlp.up_proj.bias" in safetensors_weights_manager:
state_dict[f"{import_prefix}mlp_block.c_fc.bias"] = interleave_up_gate_tensor_for_mlp(
safetensors_weights_manager.get_tensor(f"{export_prefix}mlp.up_proj.bias"),
safetensors_weights_manager.get_tensor(f"{export_prefix}mlp.gate_proj.bias"),
is_interleaved=use_interleaved_weights,
)

state_dict[f"transformer.h.{layer_idx}.mlp_block.c_proj.weight"] = safetensors_weights_manager.get_tensor(
f"model.layers.{layer_idx}.mlp.down_proj.weight"
state_dict[f"{import_prefix}mlp_block.c_proj.weight"] = safetensors_weights_manager.get_tensor(
f"{export_prefix}mlp.down_proj.weight"
)
if f"model.layers.{layer_idx}.mlp.down_proj.bias" in safetensors_weights_manager:
state_dict[f"transformer.h.{layer_idx}.mlp_block.c_proj.bias"] = safetensors_weights_manager.get_tensor(
f"model.layers.{layer_idx}.mlp.down_proj.bias"
if f"{export_prefix}mlp.down_proj.bias" in safetensors_weights_manager:
state_dict[f"{import_prefix}mlp_block.c_proj.bias"] = safetensors_weights_manager.get_tensor(
f"{export_prefix}mlp.down_proj.bias"
)

state_dict[f"transformer.h.{layer_idx}.sequence_mixer.c_attn.weight"] = (
interleave_query_key_value_tensor_for_attention(
safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.weight"),
safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.weight"),
safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.weight"),
state_dict[f"{import_prefix}sequence_mixer.c_attn.weight"] = interleave_query_key_value_tensor_for_attention(
safetensors_weights_manager.get_slice(f"{export_prefix}self_attn.q_proj.weight"),
safetensors_weights_manager.get_slice(f"{export_prefix}self_attn.k_proj.weight"),
safetensors_weights_manager.get_slice(f"{export_prefix}self_attn.v_proj.weight"),
num_attention_heads,
num_key_value_heads,
head_dim,
)
if f"{export_prefix}self_attn.q_proj.bias" in safetensors_weights_manager:
state_dict[f"{import_prefix}sequence_mixer.c_attn.bias"] = interleave_query_key_value_tensor_for_attention(
safetensors_weights_manager.get_slice(f"{export_prefix}self_attn.q_proj.bias"),
safetensors_weights_manager.get_slice(f"{export_prefix}self_attn.k_proj.bias"),
safetensors_weights_manager.get_slice(f"{export_prefix}self_attn.v_proj.bias"),
num_attention_heads,
num_key_value_heads,
head_dim,
)
)
if f"model.layers.{layer_idx}.self_attn.q_proj.bias" in safetensors_weights_manager:
state_dict[f"transformer.h.{layer_idx}.sequence_mixer.c_attn.bias"] = (
interleave_query_key_value_tensor_for_attention(
safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.bias"),
safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.bias"),
safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.bias"),
num_attention_heads,
num_key_value_heads,
head_dim,
)
)

state_dict[f"transformer.h.{layer_idx}.sequence_mixer.c_proj.weight"] = safetensors_weights_manager.get_tensor(
f"model.layers.{layer_idx}.self_attn.o_proj.weight"
state_dict[f"{import_prefix}sequence_mixer.c_proj.weight"] = safetensors_weights_manager.get_tensor(
f"{export_prefix}self_attn.o_proj.weight"
)
if f"model.layers.{layer_idx}.self_attn.o_proj.bias" in safetensors_weights_manager:
state_dict[f"transformer.h.{layer_idx}.sequence_mixer.c_proj.bias"] = (
safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.self_attn.o_proj.bias")
if f"{export_prefix}self_attn.o_proj.bias" in safetensors_weights_manager:
state_dict[f"{import_prefix}sequence_mixer.c_proj.bias"] = safetensors_weights_manager.get_tensor(
f"{export_prefix}self_attn.o_proj.bias"
)

return state_dict
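To illustrate the default path on import, a small sketch (shapes arbitrary, import path per the mlp.py file below): with is_interleaved=False the packed c_fc tensor is the plain [up; gate] concatenation, so checkpoints converted without the new flag keep the previous layout:

import torch

from lm_engine.hf_models.modeling_utils.mlp_blocks.mlp import interleave_up_gate_tensor_for_mlp

up = torch.randn(8, 4)
gate = torch.randn(8, 4)
packed = interleave_up_gate_tensor_for_mlp(up, gate, is_interleaved=False)
assert torch.equal(packed, torch.cat([up, gate]))  # old concatenated layout preserved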
@@ -180,59 +187,66 @@ def _export_llama_state_dict(config: GPTBaseConfig, safetensors_weights_manager:
state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight")

for layer_idx in range(config.num_layers):
state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = safetensors_weights_manager.get_tensor(
f"transformer.h.{layer_idx}.ln_1.weight"
import_prefix = f"transformer.h.{layer_idx}."
export_prefix = f"model.layers.{layer_idx}."

use_interleaved_weights = config.mlp_blocks[layer_idx].use_interleaved_weights

state_dict[f"{export_prefix}input_layernorm.weight"] = safetensors_weights_manager.get_tensor(
f"{import_prefix}ln_1.weight"
)
state_dict[f"model.layers.{layer_idx}.post_attention_layernorm.weight"] = (
safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.ln_2.weight")
state_dict[f"{export_prefix}post_attention_layernorm.weight"] = safetensors_weights_manager.get_tensor(
f"{import_prefix}ln_2.weight"
)

up_weight, gate_weight = split_up_gate_tensor_for_mlp(
safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_fc.weight")
safetensors_weights_manager.get_tensor(f"{import_prefix}mlp_block.c_fc.weight"),
is_interleaved=use_interleaved_weights,
)
state_dict[f"model.layers.{layer_idx}.mlp.up_proj.weight"] = up_weight
state_dict[f"model.layers.{layer_idx}.mlp.gate_proj.weight"] = gate_weight
state_dict[f"{export_prefix}mlp.up_proj.weight"] = up_weight
state_dict[f"{export_prefix}mlp.gate_proj.weight"] = gate_weight

if f"transformer.h.{layer_idx}.mlp_block.c_fc.bias" in safetensors_weights_manager:
if f"{import_prefix}mlp_block.c_fc.bias" in safetensors_weights_manager:
up_bias, gate_bias = split_up_gate_tensor_for_mlp(
safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_fc.bias")
safetensors_weights_manager.get_tensor(f"{import_prefix}mlp_block.c_fc.bias"),
is_interleaved=use_interleaved_weights,
)
state_dict[f"model.layers.{layer_idx}.mlp.up_proj.bias"] = up_bias
state_dict[f"model.layers.{layer_idx}.mlp.gate_proj.bias"] = gate_bias
state_dict[f"{export_prefix}mlp.up_proj.bias"] = up_bias
state_dict[f"{export_prefix}mlp.gate_proj.bias"] = gate_bias

state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = safetensors_weights_manager.get_tensor(
f"transformer.h.{layer_idx}.mlp_block.c_proj.weight"
state_dict[f"{export_prefix}mlp.down_proj.weight"] = safetensors_weights_manager.get_tensor(
f"{import_prefix}mlp_block.c_proj.weight"
)
if f"transformer.h.{layer_idx}.mlp_block.c_proj.bias" in safetensors_weights_manager:
state_dict[f"model.layers.{layer_idx}.mlp.down_proj.bias"] = safetensors_weights_manager.get_tensor(
f"transformer.h.{layer_idx}.mlp_block.c_proj.bias"
if f"{import_prefix}mlp_block.c_proj.bias" in safetensors_weights_manager:
state_dict[f"{export_prefix}mlp.down_proj.bias"] = safetensors_weights_manager.get_tensor(
f"{import_prefix}mlp_block.c_proj.bias"
)

query_weight, key_weight, value_weight = split_query_key_value_tensor_for_attention(
safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.sequence_mixer.c_attn.weight"),
safetensors_weights_manager.get_tensor(f"{import_prefix}sequence_mixer.c_attn.weight"),
num_attention_heads,
num_key_value_heads,
)
state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"] = query_weight
state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"] = key_weight
state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"] = value_weight
state_dict[f"{export_prefix}self_attn.q_proj.weight"] = query_weight
state_dict[f"{export_prefix}self_attn.k_proj.weight"] = key_weight
state_dict[f"{export_prefix}self_attn.v_proj.weight"] = value_weight

if f"transformer.h.{layer_idx}.sequence_mixer.c_attn.bias" in safetensors_weights_manager:
if f"{import_prefix}sequence_mixer.c_attn.bias" in safetensors_weights_manager:
query_bias, key_bias, value_bias = split_query_key_value_tensor_for_attention(
safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.sequence_mixer.c_attn.bias"),
safetensors_weights_manager.get_tensor(f"{import_prefix}sequence_mixer.c_attn.bias"),
num_attention_heads,
num_key_value_heads,
)
state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.bias"] = query_bias
state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.bias"] = key_bias
state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.bias"] = value_bias
state_dict[f"{export_prefix}self_attn.q_proj.bias"] = query_bias
state_dict[f"{export_prefix}self_attn.k_proj.bias"] = key_bias
state_dict[f"{export_prefix}self_attn.v_proj.bias"] = value_bias

state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = safetensors_weights_manager.get_tensor(
f"transformer.h.{layer_idx}.sequence_mixer.c_proj.weight"
state_dict[f"{export_prefix}self_attn.o_proj.weight"] = safetensors_weights_manager.get_tensor(
f"{import_prefix}sequence_mixer.c_proj.weight"
)
if f"transformer.h.{layer_idx}.sequence_mixer.c_proj.bias" in safetensors_weights_manager:
state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.bias"] = safetensors_weights_manager.get_tensor(
f"transformer.h.{layer_idx}.sequence_mixer.c_proj.bias"
if f"{import_prefix}sequence_mixer.c_proj.bias" in safetensors_weights_manager:
state_dict[f"{export_prefix}self_attn.o_proj.bias"] = safetensors_weights_manager.get_tensor(
f"{import_prefix}sequence_mixer.c_proj.bias"
)

return state_dict
50 changes: 44 additions & 6 deletions lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py
@@ -83,9 +83,47 @@ def _get_std_for_linear(initializer_range: float, init_method: str, m_width: flo
return std


def interleave_up_gate_tensor_for_mlp(up_weight: torch.Tensor, gate_weight: torch.Tensor) -> torch.Tensor:
return torch.cat([up_weight, gate_weight])


def split_up_gate_tensor_for_mlp(c_fc_weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return c_fc_weight.chunk(2)
def interleave_up_gate_tensor_for_mlp(
up_weight: torch.Tensor, gate_weight: torch.Tensor, is_interleaved: bool, dim: int = 0
) -> torch.Tensor:
if is_interleaved:
if dim == 0:
W = torch.empty(
2 * up_weight.size(0), *up_weight.size()[1:], dtype=up_weight.dtype, device=up_weight.device
)
W[1::2] = up_weight
W[::2] = gate_weight
elif dim == 1:
W = torch.empty(
up_weight.size(0),
2 * up_weight.size(1),
*up_weight.size()[2:],
dtype=up_weight.dtype,
device=up_weight.device,
)
W[:, 1::2] = up_weight
W[:, ::2] = gate_weight
else:
raise ValueError
else:
W = torch.cat([up_weight, gate_weight], dim=dim)

return W


def split_up_gate_tensor_for_mlp(
c_fc_weight: torch.Tensor, is_interleaved: bool, dim: int = 0
) -> tuple[torch.Tensor, torch.Tensor]:
if is_interleaved:
if dim == 0:
u = c_fc_weight[1::2].contiguous()
g = c_fc_weight[::2].contiguous()
elif dim == 1:
u = c_fc_weight[:, 1::2].contiguous()
g = c_fc_weight[:, ::2].contiguous()
else:
raise ValueError
else:
u, g = c_fc_weight.chunk(2, dim=dim)

return u, g
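A round-trip sketch for the new interleaved layout (shapes arbitrary, import path per the file location above): gate rows occupy the even indices, up rows the odd indices, and the split recovers the original tensors exactly:

import torch

from lm_engine.hf_models.modeling_utils.mlp_blocks.mlp import (
    interleave_up_gate_tensor_for_mlp,
    split_up_gate_tensor_for_mlp,
)

up = torch.randn(6, 3)
gate = torch.randn(6, 3)

packed = interleave_up_gate_tensor_for_mlp(up, gate, is_interleaved=True)
assert torch.equal(packed[::2], gate) and torch.equal(packed[1::2], up)  # gate on even rows, up on odd rows

u, g = split_up_gate_tensor_for_mlp(packed, is_interleaved=True)
assert torch.equal(u, up) and torch.equal(g, gate)  # exact round trip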