diff --git a/pkg/hfutil/modelconfig/mistral.go b/pkg/hfutil/modelconfig/mistral.go index c05abc85a..efb087206 100644 --- a/pkg/hfutil/modelconfig/mistral.go +++ b/pkg/hfutil/modelconfig/mistral.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "os" + "strings" ) // MistralConfig defines the configuration for Mistral models @@ -117,7 +118,7 @@ func (c *MistralConfig) GetArchitecture() string { if len(c.Architectures) > 0 { return c.Architectures[0] } - return "MistralModel" + return "" } func (c *MistralConfig) GetModelType() string { @@ -140,9 +141,10 @@ func (c *MistralConfig) HasVision() bool { return false } -// IsEmbedding returns true since this is an embedding model +// IsEmbedding returns true when the architecture is "MistralModel", which indicates +// an embedding model (e.g. intfloat/e5-mistral-7b-instruct). func (c *MistralConfig) IsEmbedding() bool { - return true + return strings.EqualFold(c.GetArchitecture(), "MistralModel") } // Register the Mistral model handler diff --git a/pkg/hfutil/modelconfig/mistral_test.go b/pkg/hfutil/modelconfig/mistral_test.go index 4c3b54e39..e67a880c6 100644 --- a/pkg/hfutil/modelconfig/mistral_test.go +++ b/pkg/hfutil/modelconfig/mistral_test.go @@ -84,6 +84,65 @@ func TestLoadModelWithMistral(t *testing.T) { t.Logf("Mistral model parameter count via generic loader: %s", FormatParamCount(paramCount)) } +// TestMistralConfigNoArchitectures verifies that when the Architectures field +// is missing from config.json, GetArchitecture() returns an empty string rather +// than a misleading fallback like "MistralModel" which would cause the model to +// be misclassified as an embedding model by the config parser. +// Regression test for https://github.com/ome-projects/ome/issues/601 +func TestMistralConfigNoArchitectures(t *testing.T) { + configPath := filepath.Join("testdata", "mistral_no_architectures.json") + + config, err := LoadModelConfig(configPath) + if err != nil { + t.Fatalf("Failed to load Mistral config without architectures: %v", err) + } + + // GetArchitecture must return empty string when Architectures is absent + arch := config.GetArchitecture() + if arch != "" { + t.Errorf("Expected empty architecture when Architectures field is missing, got '%s'", arch) + } + + // IsEmbedding must return false for a generic Mistral model + if config.IsEmbedding() { + t.Error("Expected IsEmbedding() to return false for a generic Mistral model without explicit embedding architecture") + } + + // Verify it's still recognized as a Mistral model + if config.GetModelType() != "mistral" { + t.Errorf("Expected model type 'mistral', got '%s'", config.GetModelType()) + } +} + +// TestMistralConfigEmbedding verifies that a Mistral model with +// "architectures": ["MistralModel"] is correctly identified as an embedding +// model (e.g. intfloat/e5-mistral-7b-instruct). +// This is the load-bearing positive case for config_parser.go's EMBEDDING classification. +func TestMistralConfigEmbedding(t *testing.T) { + configPath := filepath.Join("testdata", "mistral_embedding.json") + + config, err := LoadModelConfig(configPath) + if err != nil { + t.Fatalf("Failed to load Mistral embedding config: %v", err) + } + + // GetArchitecture must return "MistralModel" + arch := config.GetArchitecture() + if arch != "MistralModel" { + t.Errorf("Expected architecture 'MistralModel', got '%s'", arch) + } + + // GetModelType must return "mistral" + if config.GetModelType() != "mistral" { + t.Errorf("Expected model type 'mistral', got '%s'", config.GetModelType()) + } + + // IsEmbedding must return true for MistralModel architecture + if !config.IsEmbedding() { + t.Error("Expected IsEmbedding() to return true for a model with 'MistralModel' architecture") + } +} + func TestMistralInstructConfig(t *testing.T) { configPath := filepath.Join("testdata", "mistral_7b_instruct.json") diff --git a/pkg/hfutil/modelconfig/testdata/mistral_embedding.json b/pkg/hfutil/modelconfig/testdata/mistral_embedding.json new file mode 100644 index 000000000..557da2000 --- /dev/null +++ b/pkg/hfutil/modelconfig/testdata/mistral_embedding.json @@ -0,0 +1,22 @@ +{ + "architectures": ["MistralModel"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mistral", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 10000.0, + "sliding_window": 4096, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.36.0", + "use_cache": true, + "vocab_size": 32000 +} diff --git a/pkg/hfutil/modelconfig/testdata/mistral_no_architectures.json b/pkg/hfutil/modelconfig/testdata/mistral_no_architectures.json new file mode 100644 index 000000000..fcbca17f5 --- /dev/null +++ b/pkg/hfutil/modelconfig/testdata/mistral_no_architectures.json @@ -0,0 +1,21 @@ +{ + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mistral", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 10000.0, + "sliding_window": 4096, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.34.0.dev0", + "use_cache": true, + "vocab_size": 32000 +}