From 6f629d79ba587ca97843871f3497be363bcdc6ff Mon Sep 17 00:00:00 2001 From: Natarajan Krishnaswami Date: Wed, 5 Feb 2025 18:27:00 -0500 Subject: [PATCH 1/7] Minimal change to allow `attribute_keyed_dict` + test --- sqlmodel/_compat.py | 3 ++ tests/test_attribute_keyed_dict.py | 47 ++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 tests/test_attribute_keyed_dict.py diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 4e80cdc374..a23d544f4b 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -156,6 +156,9 @@ def get_relationship_to( # If a list, then also get the real field elif origin is list: use_annotation = get_args(annotation)[0] + # If a dict, then use the value type + elif origin is dict: + use_annotation = get_args(annotation)[1] return get_relationship_to( name=name, rel_info=rel_info, annotation=use_annotation diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py new file mode 100644 index 0000000000..6dfe5ffeab --- /dev/null +++ b/tests/test_attribute_keyed_dict.py @@ -0,0 +1,47 @@ +from enum import StrEnum + +from sqlalchemy.orm.collections import attribute_keyed_dict +from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine + + +def test_attribute_keyed_dict_works(clear_sqlmodel): + class Color(StrEnum): + Orange = "Orange" + Blue = "Blue" + + class Child(SQLModel, table=True): + __tablename__ = "children" + __table_args__ = ( + Index("ix_children_parent_id_color", "parent_id", "color", unique=True), + ) + + id: int | None = Field(primary_key=True, default=None) + parent_id: int = Field(foreign_key="parents.id") + color: Color + value: int + + class Parent(SQLModel, table=True): + __tablename__ = "parents" + + id: int | None = Field(primary_key=True, default=None) + children_by_color: dict[Color, Child] = Relationship( + sa_relationship_kwargs={"collection_class": attribute_keyed_dict("color")} + ) + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + parent = Parent() + session.add(parent) + session.commit() + session.refresh(parent) + session.add(Child(parent_id=parent.id, color=Color.Orange, value=1)) + session.add(Child(parent_id=parent.id, color=Color.Blue, value=2)) + session.commit() + session.refresh(parent) + assert parent.children_by_color[Color.Orange].parent_id == parent.id + assert parent.children_by_color[Color.Orange].color == Color.Orange + assert parent.children_by_color[Color.Orange].value == 1 + assert parent.children_by_color[Color.Blue].parent_id == parent.id + assert parent.children_by_color[Color.Blue].color == Color.Blue + assert parent.children_by_color[Color.Blue].value == 2 From 7f1d08587b6c90070f5e0f23d59373e7fae71d36 Mon Sep 17 00:00:00 2001 From: Natarajan Krishnaswami Date: Thu, 6 Feb 2025 17:11:22 -0500 Subject: [PATCH 2/7] Remove `StrEnum` --- tests/test_attribute_keyed_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py index 6dfe5ffeab..5e8c61ba28 100644 --- a/tests/test_attribute_keyed_dict.py +++ b/tests/test_attribute_keyed_dict.py @@ -1,11 +1,11 @@ -from enum import StrEnum +from enum import Enum from sqlalchemy.orm.collections import attribute_keyed_dict from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine def test_attribute_keyed_dict_works(clear_sqlmodel): - class Color(StrEnum): + class Color(str, Enum): Orange = "Orange" Blue = "Blue" From ad7f6bbc1a96ca48fa50307ba4c49b99b0392d32 Mon Sep 17 00:00:00 2001 From: Natarajan Krishnaswami Date: Thu, 6 Feb 2025 17:30:45 -0500 Subject: [PATCH 3/7] Remove type union pipe syntax --- tests/test_attribute_keyed_dict.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py index 5e8c61ba28..a55a927f4a 100644 --- a/tests/test_attribute_keyed_dict.py +++ b/tests/test_attribute_keyed_dict.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Optional from sqlalchemy.orm.collections import attribute_keyed_dict from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine @@ -15,7 +16,7 @@ class Child(SQLModel, table=True): Index("ix_children_parent_id_color", "parent_id", "color", unique=True), ) - id: int | None = Field(primary_key=True, default=None) + id: Optional[int] = Field(primary_key=True, default=None) parent_id: int = Field(foreign_key="parents.id") color: Color value: int @@ -23,7 +24,7 @@ class Child(SQLModel, table=True): class Parent(SQLModel, table=True): __tablename__ = "parents" - id: int | None = Field(primary_key=True, default=None) + id: Optional[int] = Field(primary_key=True, default=None) children_by_color: dict[Color, Child] = Relationship( sa_relationship_kwargs={"collection_class": attribute_keyed_dict("color")} ) From 563886a03e7e5e87af95a2f7ee4c079f3127ed3d Mon Sep 17 00:00:00 2001 From: Natarajan Krishnaswami Date: Thu, 6 Feb 2025 19:02:28 -0500 Subject: [PATCH 4/7] Use `Dict[]` instead of `dict[]` --- tests/test_attribute_keyed_dict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py index a55a927f4a..9d06196396 100644 --- a/tests/test_attribute_keyed_dict.py +++ b/tests/test_attribute_keyed_dict.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import Dict, Optional from sqlalchemy.orm.collections import attribute_keyed_dict from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine @@ -25,7 +25,7 @@ class Parent(SQLModel, table=True): __tablename__ = "parents" id: Optional[int] = Field(primary_key=True, default=None) - children_by_color: dict[Color, Child] = Relationship( + children_by_color: Dict[Color, Child] = Relationship( sa_relationship_kwargs={"collection_class": attribute_keyed_dict("color")} ) From fb43fd49987d8237742ef97474ff4d9ef1ab70aa Mon Sep 17 00:00:00 2001 From: Natarajan Krishnaswami Date: Wed, 1 Oct 2025 11:25:48 -0400 Subject: [PATCH 5/7] Remove superfluous index in test case Co-authored-by: Motov Yurii <109919500+YuriiMotov@users.noreply.github.com> --- tests/test_attribute_keyed_dict.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py index 9d06196396..cbd8847e2a 100644 --- a/tests/test_attribute_keyed_dict.py +++ b/tests/test_attribute_keyed_dict.py @@ -12,9 +12,6 @@ class Color(str, Enum): class Child(SQLModel, table=True): __tablename__ = "children" - __table_args__ = ( - Index("ix_children_parent_id_color", "parent_id", "color", unique=True), - ) id: Optional[int] = Field(primary_key=True, default=None) parent_id: int = Field(foreign_key="parents.id") From 8a693f022cf44c688b18727b6ee75b75783edbce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Oct 2025 15:25:56 +0000 Subject: [PATCH 6/7] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_attribute_keyed_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py index cbd8847e2a..35968592bd 100644 --- a/tests/test_attribute_keyed_dict.py +++ b/tests/test_attribute_keyed_dict.py @@ -2,7 +2,7 @@ from typing import Dict, Optional from sqlalchemy.orm.collections import attribute_keyed_dict -from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine +from sqlmodel import Field, Relationship, Session, SQLModel, create_engine def test_attribute_keyed_dict_works(clear_sqlmodel): From ac996bfb6c55a42c43081ce9a38fc591612f9e3f Mon Sep 17 00:00:00 2001 From: Natarajan Krishnaswami Date: Wed, 29 Oct 2025 16:39:51 -0400 Subject: [PATCH 7/7] Make dict get_relationship_to more robust --- sqlmodel/_compat.py | 13 +++++++-- tests/test_attribute_keyed_dict.py | 45 ++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 9797dbf257..cade3242e7 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -164,9 +164,16 @@ def get_relationship_to( # If a list, then also get the real field elif origin is list: use_annotation = get_args(annotation)[0] - # If a dict, then use the value type - elif origin is dict: - use_annotation = get_args(annotation)[1] + # If a dict or Mapping, then use the value (second) type argument + elif origin is dict or origin is Mapping: + args = get_args(annotation) + if len(args) >= 2: + use_annotation = args[1] + else: + raise ValueError( + f"Dict/Mapping relationship field '{name}' must have both " + "key and value type arguments (e.g., dict[str, Model])" + ) return get_relationship_to( name=name, rel_info=rel_info, annotation=use_annotation diff --git a/tests/test_attribute_keyed_dict.py b/tests/test_attribute_keyed_dict.py index 35968592bd..6aaf38b6eb 100644 --- a/tests/test_attribute_keyed_dict.py +++ b/tests/test_attribute_keyed_dict.py @@ -1,6 +1,8 @@ +import re from enum import Enum from typing import Dict, Optional +import pytest from sqlalchemy.orm.collections import attribute_keyed_dict from sqlmodel import Field, Relationship, Session, SQLModel, create_engine @@ -43,3 +45,46 @@ class Parent(SQLModel, table=True): assert parent.children_by_color[Color.Blue].parent_id == parent.id assert parent.children_by_color[Color.Blue].color == Color.Blue assert parent.children_by_color[Color.Blue].value == 2 + + +def test_dict_relationship_throws_on_missing_annotation_arg(clear_sqlmodel): + class Color(str, Enum): + Orange = "Orange" + Blue = "Blue" + + class Child(SQLModel, table=True): + __tablename__ = "children" + + id: Optional[int] = Field(primary_key=True, default=None) + parent_id: int = Field(foreign_key="parents.id") + color: Color + value: int + + error_msg_re = re.escape( + "Dict/Mapping relationship field 'children_by_color' must have both key and value type arguments (e.g., dict[str, Model])" + ) + # No type args + with pytest.raises(ValueError, match=error_msg_re): + + class Parent(SQLModel, table=True): + __tablename__ = "parents" + + id: Optional[int] = Field(primary_key=True, default=None) + children_by_color: dict[()] = Relationship( + sa_relationship_kwargs={ + "collection_class": attribute_keyed_dict("color") + } + ) + + # One type arg + with pytest.raises(ValueError, match=error_msg_re): + + class Parent(SQLModel, table=True): + __tablename__ = "parents" + + id: Optional[int] = Field(primary_key=True, default=None) + children_by_color: dict[Color] = Relationship( + sa_relationship_kwargs={ + "collection_class": attribute_keyed_dict("color") + } + )