Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,29 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"Conversion to python value expected type {expected_python_type} from literal not implemented"
)

def schema_match(self, schema: dict) -> bool:
    """Report whether a JSON schema fragment describes this transformer's ``python_type``.

    A match is only ever reported when ``python_type`` exposes a
    ``model_json_schema`` attribute and is not pydantic's ``BaseModel`` class
    itself. In that case the fragment matches when its ``title``, its ``type``
    and its set of ``required`` field names all equal those of the type's own
    JSON schema. Every other transformer gets ``False``; subclasses may
    override this method to implement their own matching.

    :param schema: JSON schema fragment to test; anything that is not a dict
        yields ``False``.
    :return: ``True`` on a match, ``False`` otherwise, including when any
        error occurs while importing pydantic or building/comparing schemas.
    """
    if not isinstance(schema, dict):
        return False
    try:
        from pydantic import BaseModel

        candidate = self.python_type
        if not hasattr(candidate, "model_json_schema") or candidate is BaseModel:
            return False

        own_schema = candidate.model_json_schema()  # type: ignore[attr-defined]
        for key in ("title", "type"):
            if schema.get(key) != own_schema.get(key):
                return False
        # Order of the required fields is irrelevant, so compare them as sets.
        return set(schema.get("required", [])) == set(own_schema.get("required", []))
    except Exception:
        # Best-effort check: any failure is reported as "no match" rather
        # than propagated to the caller.
        return False

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]:
"""
This function primarily handles deserialization for untyped dicts, dataclasses, Pydantic BaseModels, and attribute access.`
Expand Down Expand Up @@ -1209,6 +1232,9 @@ def _handle_json_schema_property(
elif property_val.get("title"):
# For nested dataclass
sub_schema_name = property_val["title"]
matched_type = _match_registered_type_from_schema(property_val)
if matched_type is not None:
return (property_key, typing.cast(GenericAlias, matched_type))
return (
property_key,
typing.cast(GenericAlias, convert_mashumaro_json_schema_to_python_class(property_val, sub_schema_name)),
Expand All @@ -1223,6 +1249,14 @@ def _handle_json_schema_property(
return (property_key, _get_element_type(property_val, schema)) # type: ignore


def _match_registered_type_from_schema(schema: dict) -> typing.Optional[type]:
    """Return the python type of the first registered transformer matching ``schema``.

    Walks the transformers in ``TypeEngine._REGISTRY`` in iteration order and
    asks each one whether it recognises the JSON schema fragment via its
    ``schema_match`` method.

    :param schema: JSON schema fragment to look up.
    :return: ``python_type`` of the first transformer that matches, or
        ``None`` when no transformer matches.
    """
    matching_types = (
        registered.python_type
        for registered in TypeEngine._REGISTRY.values()  # type: ignore[misc]
        if registered.schema_match(schema)
    )
    # The generator is lazy, so the scan stops at the first match.
    return next(matching_types, None)


def generate_attribute_list_from_dataclass_json_mixin(
schema: typing.Dict[str, typing.Any],
schema_name: typing.Any,
Expand All @@ -1243,6 +1277,11 @@ def generate_attribute_list_from_dataclass_json_mixin(
if ref_schema.get("enum"):
attribute_list.append((property_key, str))
continue
# Check if the $ref matches a registered custom type
matched_type = _match_registered_type_from_schema(ref_schema)
if matched_type is not None:
attribute_list.append((property_key, typing.cast(GenericAlias, matched_type)))
continue
# Include $defs so nested models can resolve their own $refs
if "$defs" not in ref_schema and defs:
ref_schema["$defs"] = defs
Expand Down Expand Up @@ -2553,6 +2592,9 @@ def _get_element_type(
# Guard the nested enum elements inside containers
if ref_schema.get("enum"):
return str
# Check if the $ref matches a registered custom type
if (matched_type := _match_registered_type_from_schema(ref_schema)) is not None:
return matched_type
# if defs not in the schema, they need to be propagated into the resolved schema
if "$defs" not in ref_schema and defs:
ref_schema["$defs"] = defs
Expand Down
174 changes: 174 additions & 0 deletions tests/flytekit/unit/core/test_custom_type_in_nested_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Tests for custom TypeTransformer schema_match in nested Pydantic models.

Verifies that when a custom type with a registered TypeTransformer (implementing
schema_match) is nested inside a Pydantic BaseModel, guess_python_type correctly
reconstructs the custom type instead of building a generic dataclass.
"""

import dataclasses
import typing

import pytest
from pydantic import BaseModel

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, _match_registered_type_from_schema
from flytekit.models.literals import Literal
from flytekit.models.types import LiteralType, SimpleType

# -- Custom type and transformer --


# Small pydantic model used as the "custom type" throughout this test module:
# a CoordinateTransformer is registered for it, and the tests expect
# guess_python_type to hand back this exact class for fields of this type.
class Coordinate(BaseModel):
    x: float
    y: float


class CoordinateTransformer(TypeTransformer[Coordinate]):
    """Test-only transformer for ``Coordinate``.

    ``schema_match`` is not overridden here, so schema matching relies on the
    implementation inherited from ``TypeTransformer``. The literal conversion
    methods raise ``NotImplementedError``; the tests visible in this module
    only exercise type guessing and schema matching.
    """

    def __init__(self):
        # Pass the name "Coordinate" and the Coordinate class to the base constructor.
        super().__init__("Coordinate", Coordinate)

    def get_literal_type(self, t=None) -> LiteralType:
        """Return a ``STRUCT`` simple literal type; ``t`` is ignored."""
        return LiteralType(simple=SimpleType.STRUCT)

    def to_literal(self, ctx, python_val, python_type, expected) -> Literal:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def to_python_value(self, ctx, lv, expected_python_type):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError


# -- Models using Coordinate --


# Model holding a Coordinate directly as a field
# (used by test_coord_in_model_guess_type).
class ModelWithCoord(BaseModel):
    label: str
    coord: Coordinate


# Model holding Coordinate as the element type of a list
# (used by test_list_of_coords_guess_type).
class ModelWithListOfCoords(BaseModel):
    coords: typing.List[Coordinate]


# Model holding Coordinate as the value type of a string-keyed dict
# (used by test_dict_of_coords_guess_type).
class ModelWithDictOfCoords(BaseModel):
    coord_map: typing.Dict[str, Coordinate]


# Model holding an optional Coordinate that defaults to None.
# NOTE(review): no test in the visible part of this module references this
# model -- confirm whether an Optional[Coordinate] test is missing.
class ModelWithOptionalCoord(BaseModel):
    coord: typing.Optional[Coordinate] = None


# Model holding Coordinate two list levels deep
# (used by test_nested_list_of_coords_guess_type).
class ModelWithNestedListOfCoords(BaseModel):
    nested: typing.List[typing.List[Coordinate]]


# -- Fixtures --


@pytest.fixture(autouse=True)
def register_coordinate_transformer():
    """Install a CoordinateTransformer around every test in this module.

    A fresh transformer is registered with the TypeEngine before the test
    body runs; afterwards the ``Coordinate`` entry is removed from the
    registry again.
    """
    TypeEngine.register(CoordinateTransformer())
    yield
    # Teardown: drop the Coordinate entry; the default keeps pop() from
    # raising if the entry is already gone.
    TypeEngine._REGISTRY.pop(Coordinate, None)


# -- Unit tests for _match_registered_type_from_schema --


def test_match_returns_coordinate_for_matching_schema():
    """Coordinate's own JSON schema must resolve back to the Coordinate class."""
    matched = _match_registered_type_from_schema(Coordinate.model_json_schema())
    assert matched is Coordinate


def test_match_returns_none_for_unmatched_schema():
    """A schema titled "Unknown" must not resolve to any registered type."""
    unknown_schema = {
        "type": "object",
        "title": "Unknown",
        "properties": {"a": {"type": "string"}},
        "required": ["a"],
    }
    assert _match_registered_type_from_schema(unknown_schema) is None


# -- guess_python_type structure verification --


def test_coord_in_model_guess_type():
    """The guessed type for ModelWithCoord must expose ``coord`` as Coordinate itself."""
    literal_type = TypeEngine.to_literal_type(ModelWithCoord)
    rebuilt = TypeEngine.guess_python_type(literal_type)
    assert dataclasses.is_dataclass(rebuilt)

    field_types = typing.get_type_hints(rebuilt)
    assert "coord" in field_types
    assert field_types["coord"] is Coordinate


def test_list_of_coords_guess_type():
    """The guessed type for ModelWithListOfCoords must expose ``coords`` as a list of Coordinate."""
    literal_type = TypeEngine.to_literal_type(ModelWithListOfCoords)
    rebuilt = TypeEngine.guess_python_type(literal_type)
    assert dataclasses.is_dataclass(rebuilt)

    coords_hint = typing.get_type_hints(rebuilt)["coords"]
    assert typing.get_origin(coords_hint) is list
    element_type = typing.get_args(coords_hint)[0]
    assert element_type is Coordinate


def test_dict_of_coords_guess_type():
    """The guessed type for ModelWithDictOfCoords must expose ``coord_map`` as a dict of str to Coordinate."""
    literal_type = TypeEngine.to_literal_type(ModelWithDictOfCoords)
    rebuilt = TypeEngine.guess_python_type(literal_type)
    assert dataclasses.is_dataclass(rebuilt)

    mapping_hint = typing.get_type_hints(rebuilt)["coord_map"]
    assert typing.get_origin(mapping_hint) is dict
    key_hint, value_hint = typing.get_args(mapping_hint)
    assert key_hint is str
    assert value_hint is Coordinate


def test_nested_list_of_coords_guess_type():
    """The guessed type for ModelWithNestedListOfCoords must expose ``nested`` as a list of lists of Coordinate."""
    literal_type = TypeEngine.to_literal_type(ModelWithNestedListOfCoords)
    rebuilt = TypeEngine.guess_python_type(literal_type)
    assert dataclasses.is_dataclass(rebuilt)

    current = typing.get_type_hints(rebuilt)["nested"]
    # Peel off the outer list, then the inner list.
    for _level in range(2):
        assert typing.get_origin(current) is list
        current = typing.get_args(current)[0]

    assert current is Coordinate


# -- schema_match default behavior --


def test_base_transformer_schema_match_returns_false():
    """A transformer for ``str`` that keeps the inherited schema_match must report no match."""

    class DummyTransformer(TypeTransformer[str]):
        def __init__(self):
            super().__init__("Dummy", str)

        def get_literal_type(self, t=None):
            return LiteralType(simple=SimpleType.STRING)

        def to_literal(self, ctx, val, typ, expected):
            raise NotImplementedError

        def to_python_value(self, ctx, lv, typ):
            raise NotImplementedError

    dummy = DummyTransformer()
    # Check a populated schema first, then an empty one.
    for candidate_schema in ({"type": "string"}, {}):
        assert dummy.schema_match(candidate_schema) is False
Loading