Skip to content

Commit 8fa368e

Browse files
committed
fix failed testcases
1 parent f8e6347 commit 8fa368e

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

tests/input/test_ernie_vl_processor.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ def setUp(self):
149149
self.mock_tokenizer.convert_tokens_to_ids.side_effect = self._mock_convert_tokens_to_ids
150150
self.mock_tokenizer.chat_template = "mock_template"
151151
self.mock_tokenizer.apply_chat_template.return_value = "User: Hello<|image@placeholder|>"
152+
# Mock encode method for _add_text
153+
self.mock_tokenizer.encode = MagicMock(return_value={"input_ids": [1, 2, 3]})
152154

153155
def mock_load_tokenizer(dp_instance):
154156
dp_instance.tokenizer = self.mock_tokenizer
@@ -168,6 +170,7 @@ def mock_load_tokenizer(dp_instance):
168170
self.data_processor.video_end_id = 1005
169171
self.data_processor.role_prefixes = {"user": "User: ", "assistant": "Assistant: "}
170172
self.data_processor.enable_processor_cache = False
173+
# Note: extract_mm_items is not mocked by default, only when needed
171174
self.data_processor.extract_mm_items = MagicMock(return_value=([], [], [], [], None, [], []))
172175

173176
def _mock_convert_tokens_to_ids(self, token):
@@ -196,7 +199,7 @@ def test_prompt_token_ids2outputs_only_prompt_token_ids(self):
196199
self.assertEqual(
197200
outputs["input_ids"],
198201
test_prompt_token_ids,
199-
f"input_ids 涓嶅尮閰嶏細瀹為檯{outputs['input_ids']}锛岄鏈焄{test_prompt_token_ids}]",
202+
f"input_ids mismatch: actual {outputs['input_ids']}, expected {test_prompt_token_ids}",
200203
)
201204

202205
self.assertEqual(outputs["token_type_ids"], [IDS_TYPE_FLAG["text"]] * prompt_len)
@@ -563,6 +566,8 @@ def test_prompt_token_ids2outputs_add_processed_video_token_len_mismatch(self):
563566
def test_text2ids_basic(self):
564567
"""Test text2ids with basic text input"""
565568
text = "Hello world"
569+
# Ensure encode returns proper format
570+
self.mock_tokenizer.encode.return_value = {"input_ids": [1, 2, 3]}
566571
outputs = self.data_processor.text2ids(text)
567572

568573
self.assertIn("input_ids", outputs)
@@ -608,6 +613,8 @@ def test_text2ids_with_video_placeholder(self):
608613
def test_request2ids_basic(self):
609614
"""Test request2ids with basic request"""
610615
self.data_processor.is_training = False
616+
# Fix apply_chat_template to return text without image placeholder
617+
self.mock_tokenizer.apply_chat_template.return_value = "User: Hello"
611618
request = {
612619
"messages": [{"role": "user", "content": "Hello"}],
613620
"add_generation_prompt": True,
@@ -624,6 +631,8 @@ def test_request2ids_with_multimodal(self):
624631
"""Test request2ids with multimodal content"""
625632
self.data_processor.is_training = False
626633
mock_image = Image.new("RGB", (224, 224))
634+
# Fix apply_chat_template to return text with image placeholder matching the image
635+
self.mock_tokenizer.apply_chat_template.return_value = "User: What's in this image?<|image@placeholder|>"
627636
request = {
628637
"messages": [
629638
{
@@ -672,6 +681,11 @@ def test_extract_mm_items_basic(self):
672681
]
673682
}
674683

684+
# Restore real extract_mm_items method for this test
685+
from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor
686+
687+
original_extract_mm_items = DataProcessor.extract_mm_items
688+
675689
with patch("fastdeploy.input.ernie4_5_vl_processor.process.parse_chat_messages") as mock_parse:
676690
mock_parse.return_value = [
677691
{
@@ -683,6 +697,10 @@ def test_extract_mm_items_basic(self):
683697
],
684698
}
685699
]
700+
# Use real extract_mm_items method (cache is disabled, so no zmq connection needed)
701+
self.data_processor.extract_mm_items = original_extract_mm_items.__get__(
702+
self.data_processor, DataProcessor
703+
)
686704
images, videos, image_uuid, video_uuid, dealer, missing_idx, mm_items = (
687705
self.data_processor.extract_mm_items(request)
688706
)
@@ -698,8 +716,17 @@ def test_extract_mm_items_missing_data_error(self):
698716
self.data_processor.enable_processor_cache = False
699717
request = {"messages": [{"role": "user", "content": [{"type": "image", "uuid": "img1"}]}]}
700718

719+
# Restore real extract_mm_items method for this test
720+
from fastdeploy.input.ernie4_5_vl_processor.process import DataProcessor
721+
722+
original_extract_mm_items = DataProcessor.extract_mm_items
723+
701724
with patch("fastdeploy.input.ernie4_5_vl_processor.process.parse_chat_messages") as mock_parse:
702725
mock_parse.return_value = [{"role": "user", "content": [{"type": "image", "uuid": "img1"}]}]
726+
# Use real extract_mm_items method
727+
self.data_processor.extract_mm_items = original_extract_mm_items.__get__(
728+
self.data_processor, DataProcessor
729+
)
703730
with self.assertRaises(ValueError) as ctx:
704731
self.data_processor.extract_mm_items(request)
705732
self.assertIn("Missing items cannot be retrieved", str(ctx.exception))

0 commit comments

Comments
 (0)