diff --git a/python/convert_pytorch_to_ggml.py b/python/convert_pytorch_to_ggml.py index 99568449..77f62f5a 100644 --- a/python/convert_pytorch_to_ggml.py +++ b/python/convert_pytorch_to_ggml.py @@ -32,15 +32,21 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t n_vocab: int = emb_weight.shape[0] n_embed: int = emb_weight.shape[1] - is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict - is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict - - if is_v5_2: - print('Detected RWKV v5.2') - elif is_v5_1_or_2: - print('Detected RWKV v5.1') - else: - print('Detected RWKV v4') + version = 4 + keys = list(state_dict.keys()) + for k in keys: + if 'ln_x' in k: + version = max(5, version) + if 'gate.weight' in k: + version = max(5.1, version) + if int(version) == 5 and 'att.time_decay' in k: + if len(state_dict[k].shape) > 1: + if (state_dict[k].shape[1]) > 1: + version = max(5.2, version) + if "time_maa" in k: + version = max(6, version) + + print(f'Model detected v{version:.1f}') with open(dest_path, 'wb') as out_file: is_FP16: bool = data_type == 'FP16' or data_type == 'float16' @@ -57,15 +63,16 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t 1 if is_FP16 else 0 )) - for k in state_dict.keys(): + keys = list(state_dict.keys()) + for k in keys: tensor: torch.Tensor = state_dict[k].float() if '.time_' in k: tensor = tensor.squeeze() - if is_v5_1_or_2: + if int(version) == 5: if '.time_decay' in k: - if is_v5_2: + if version == 5.2: tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1) else: tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)