diff --git a/Makefile b/Makefile index ee3f65f..e7e54c7 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,8 @@ SYNC_NAME := redis_om INSTALL_STAMP := .install.stamp UV := $(shell command -v uv 2> /dev/null) REDIS_OM_URL ?= redis://localhost:6380?decode_responses=True +DOCKER_COMPOSE := docker compose +CLUSTER_COMPOSE := $(DOCKER_COMPOSE) -f docker-compose.cluster.yml .DEFAULT_GOAL := help @@ -40,7 +42,8 @@ clean: rm -rf redis_om rm -rf tests_sync rm -rf .venv - -docker-compose down + -$(DOCKER_COMPOSE) down + -$(CLUSTER_COMPOSE) down .PHONY: dist @@ -68,7 +71,7 @@ format: $(INSTALL_STAMP) sync .PHONY: test test: $(INSTALL_STAMP) sync redis REDIS_OM_URL=$(REDIS_OM_URL) $(UV) run pytest -n auto -vv ./tests/ ./tests_sync/ --cov-report term-missing --cov $(NAME) $(SYNC_NAME) - docker-compose down + $(DOCKER_COMPOSE) down .PHONY: test_oss test_oss: $(INSTALL_STAMP) sync redis @@ -84,7 +87,37 @@ shell: $(INSTALL_STAMP) .PHONY: redis redis: - docker-compose up -d + $(DOCKER_COMPOSE) up -d + +.PHONY: redis_cluster +redis_cluster: + $(CLUSTER_COMPOSE) up -d + @echo "Waiting for Redis Cluster nodes to start..." + @sleep 5 + @cluster_init_container=$$($(CLUSTER_COMPOSE) ps -q redis-cluster-7001); \ + if ! docker exec $$cluster_init_container redis-cli -p 7001 cluster info 2>/dev/null | grep -q "cluster_state:ok"; then \ + docker exec $$cluster_init_container redis-cli --cluster create \ + 127.0.0.1:7001 127.0.0.1:7002 127.0.0.1:7003 \ + 127.0.0.1:7004 127.0.0.1:7005 127.0.0.1:7006 \ + --cluster-replicas 1 --cluster-yes; \ + fi + @echo "Waiting for Redis Cluster to become healthy..." + @cluster_init_container=$$($(CLUSTER_COMPOSE) ps -q redis-cluster-7001); \ + for attempt in 1 2 3 4 5 6 7 8 9 10; do \ + if docker exec $$cluster_init_container redis-cli -p 7001 cluster info 2>/dev/null | grep -q "cluster_state:ok"; then \ + exit 0; \ + fi; \ + sleep 2; \ + done; \ + echo "Redis Cluster did not become healthy in time" >&2; \ + exit 1 + +.PHONY: test_cluster +test_cluster: $(INSTALL_STAMP) sync redis redis_cluster + REDIS_OM_URL=$(REDIS_OM_URL) $(UV) run pytest -vv ./tests/test_cluster_operations.py + REDIS_OM_URL=$(REDIS_OM_URL) $(UV) run pytest -vv ./tests_sync/test_cluster_operations.py + $(CLUSTER_COMPOSE) down + $(DOCKER_COMPOSE) down .PHONY: upload upload: dist diff --git a/README.md b/README.md index 551d963..4e0a9f2 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,19 @@ Cluster support includes: - RediSearch-backed queries, including embedded JSON and GEO lookups - migrator support for creating search indexes on cluster deployments +For local cluster validation, this repository includes a dedicated 6-node Redis +Cluster test setup on ports `7001-7006`: + +```sh +make redis +make redis_cluster +make test_cluster +``` + +`get_redis_connection()` accepts either `cluster=True` or `cluster=true` in the +URL and strips the query flag before handing the URL to redis-py, so other URL +parameters such as `decode_responses=True` continue to work unchanged. + ## 📇 Modeling Your Data Redis OM contains powerful declarative models that give you data validation, serialization, and persistence to Redis. @@ -798,6 +811,10 @@ We'd love your contributions! You can also **contribute documentation** -- or just let us know if something needs more detail. [Open an issue on GitHub](https://github.com/XChikuX/redis-om-python/issues/new) to get started. +Current local coverage baseline: **88% overall** across `aredis_om/` and the +generated `redis_om/` mirror, with **1168 passing tests** plus the cluster test +suite. + ## 📝 License Redis OM uses the [MIT license][license-url]. diff --git a/aredis_om/connections.py b/aredis_om/connections.py index 1c432af..46a9726 100644 --- a/aredis_om/connections.py +++ b/aredis_om/connections.py @@ -31,10 +31,13 @@ def get_redis_connection(**kwargs) -> Union[redis.Redis, redis.RedisCluster]: def _strip_cluster_param(url: str) -> str: """Remove 'cluster=true' from URL query parameters.""" - from urllib.parse import parse_qs, urlencode, urlparse, urlunparse + from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse parsed = urlparse(url) - params = parse_qs(parsed.query, keep_blank_values=True) - params.pop("cluster", None) + params = [ + (key, value) + for key, value in parse_qsl(parsed.query, keep_blank_values=True) + if key.lower() != "cluster" + ] new_query = urlencode(params, doseq=True) return urlunparse(parsed._replace(query=new_query)) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index e82b1e1..5e3d160 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -29,7 +29,9 @@ Union, ) from typing import get_args as typing_get_args -from typing import no_type_check +from typing import ( + no_type_check, +) from more_itertools import ichunked from pydantic import ConfigDict @@ -182,7 +184,19 @@ def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any obj = model for sub_field in field_name.split("__"): if not isinstance(obj, ModelMeta) and hasattr(obj, "field"): - obj = getattr(obj, "field").annotation + annotation = getattr(obj, "field").annotation + # Unwrap Optional[X] (i.e. Union[X, None]) so that we can + # traverse into the inner model's fields. + if get_origin(annotation) is Union: + annotation = next( + ( + a + for a in typing_get_args(annotation) + if a is not type(None) + ), + annotation, + ) + obj = annotation if not hasattr(obj, sub_field): raise QuerySyntaxError( @@ -3054,21 +3068,23 @@ def schema_for_type( "In this Preview release, list and tuple fields can only " f"contain strings. Problem field: {name}. See docs: TODO" ) - if full_text_search is True: - raise RedisModelError( - "List and tuple fields cannot be indexed for full-text " - f"search. Problem field: {name}. See docs: TODO" - ) if sortable is True: raise RedisModelError( "In this Preview release, list and tuple fields cannot be " f"marked as sortable. Problem field: {name}. See docs: TODO" ) + if case_sensitive is True and full_text_search is True: + raise RedisModelError( + f"List field '{name}' cannot be both case-sensitive and " + "full-text searchable." + ) separator = getattr( field_info, "separator", SINGLE_VALUE_TAG_FIELD_SEPARATOR ) schema = f"{path} AS {index_field_name} TAG SEPARATOR {separator}" - if case_sensitive is True: + if full_text_search is True: + schema += f" {path} AS {index_field_name}_fts TEXT" + elif case_sensitive is True: schema += " CASESENSITIVE" elif typ is bool: schema = f"{path} AS {index_field_name} TAG" diff --git a/docker-compose.cluster.yml b/docker-compose.cluster.yml new file mode 100644 index 0000000..2f283a5 --- /dev/null +++ b/docker-compose.cluster.yml @@ -0,0 +1,108 @@ +services: + redis-cluster-7001: + image: "redis:8-alpine" + restart: always + network_mode: host + command: > + redis-server + --port 7001 + --dir /data + --save "" + --appendonly no + --protected-mode no + --cluster-enabled yes + --cluster-config-file nodes-7001.conf + --cluster-node-timeout 5000 + --cluster-announce-ip 127.0.0.1 + --cluster-announce-port 7001 + --cluster-announce-bus-port 17001 + + redis-cluster-7002: + image: "redis:8-alpine" + restart: always + network_mode: host + command: > + redis-server + --port 7002 + --dir /data + --save "" + --appendonly no + --protected-mode no + --cluster-enabled yes + --cluster-config-file nodes-7002.conf + --cluster-node-timeout 5000 + --cluster-announce-ip 127.0.0.1 + --cluster-announce-port 7002 + --cluster-announce-bus-port 17002 + + redis-cluster-7003: + image: "redis:8-alpine" + restart: always + network_mode: host + command: > + redis-server + --port 7003 + --dir /data + --save "" + --appendonly no + --protected-mode no + --cluster-enabled yes + --cluster-config-file nodes-7003.conf + --cluster-node-timeout 5000 + --cluster-announce-ip 127.0.0.1 + --cluster-announce-port 7003 + --cluster-announce-bus-port 17003 + + redis-cluster-7004: + image: "redis:8-alpine" + restart: always + network_mode: host + command: > + redis-server + --port 7004 + --dir /data + --save "" + --appendonly no + --protected-mode no + --cluster-enabled yes + --cluster-config-file nodes-7004.conf + --cluster-node-timeout 5000 + --cluster-announce-ip 127.0.0.1 + --cluster-announce-port 7004 + --cluster-announce-bus-port 17004 + + redis-cluster-7005: + image: "redis:8-alpine" + restart: always + network_mode: host + command: > + redis-server + --port 7005 + --dir /data + --save "" + --appendonly no + --protected-mode no + --cluster-enabled yes + --cluster-config-file nodes-7005.conf + --cluster-node-timeout 5000 + --cluster-announce-ip 127.0.0.1 + --cluster-announce-port 7005 + --cluster-announce-bus-port 17005 + + redis-cluster-7006: + image: "redis:8-alpine" + restart: always + network_mode: host + command: > + redis-server + --port 7006 + --dir /data + --save "" + --appendonly no + --protected-mode no + --cluster-enabled yes + --cluster-config-file nodes-7006.conf + --cluster-node-timeout 5000 + --cluster-announce-ip 127.0.0.1 + --cluster-announce-port 7006 + --cluster-announce-bus-port 17006 diff --git a/make_sync.py b/make_sync.py index 6aa66d2..89f47f3 100644 --- a/make_sync.py +++ b/make_sync.py @@ -6,15 +6,45 @@ ADDITIONAL_REPLACEMENTS = { "aredis_om": "redis_om", "async_redis": "sync_redis", + "redis.asyncio as aioredis": "redis as aioredis", ":tests.": ":tests_sync.", "pytest_asyncio": "pytest", "py_test_mark_asyncio": "py_test_mark_sync", "pytest.mark.asyncio(f)": "f", "pytest.mark.asyncio": "py_test_mark_sync", + ".aclose()": ".close()", } +POST_SYNC_FIXES = { + "tests_sync/test_cluster_operations.py": { + "import redis.asyncio as aioredis": "import redis as aioredis", + "conn.aclose()": "conn.close()", + # In the generated sync mirror these call sites already contain eager + # return values, not coroutines, so the async gather wrapper must be + # removed. + "asyncio.gather(*tasks)": "tasks", + } +} + + +def apply_post_sync_fixes(repo_root: Path): + for relative_path, replacements in POST_SYNC_FIXES.items(): + file_path = repo_root / relative_path + if not file_path.exists(): + continue + + content = file_path.read_text() + updated = content + for old, new in replacements.items(): + updated = updated.replace(old, new) + + if updated != content: + file_path.write_text(updated) + + def main(): + repo_root = Path(__file__).absolute().parent rules = [ unasync.Rule( fromdir="/aredis_om/", @@ -28,7 +58,7 @@ def main(): ), ] filepaths = [] - for root, _, filenames in os.walk(Path(__file__).absolute().parent): + for root, _, filenames in os.walk(repo_root): for filename in filenames: if filename.rpartition(".")[-1] in ( "py", @@ -37,6 +67,7 @@ def main(): filepaths.append(os.path.join(root, filename)) unasync.unasync_files(filepaths, rules) + apply_post_sync_fixes(repo_root) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 9153d63..a84a344 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pyredis-om" -version = "0.6.0" +version = "0.6.1" description = "A drop-in replacement for `redis-om`, built out of frustration." authors = [ { name = "Redis OSS", email = "oss@redis.com" }, diff --git a/tests/test_cluster_operations.py b/tests/test_cluster_operations.py index b76f1c3..5691058 100644 --- a/tests/test_cluster_operations.py +++ b/tests/test_cluster_operations.py @@ -12,7 +12,7 @@ - GEO operations (save, search, GeoFilter) on cluster - Full-text search on cluster - Complex queries (AND, OR, NOT, IN, range) on cluster -- Index creation / migration on cluster (FT.CREATE via target_nodes=PRIMARIES) +- Index creation / migration on cluster (FT.CREATE via target_nodes=RANDOM) - Pipeline and batch operations on cluster - Performance comparison vs single-instance (pass/fail based on slowdown factor) - Direct Redis verification before redis-om layer queries @@ -1421,7 +1421,7 @@ async def test_cluster_pipeline_mixed_ops(cluster_json_models, cluster_hash_mode @py_test_mark_asyncio async def test_cluster_migration_creates_indexes(cluster_conn): - """Cluster: Migrator creates indexes on cluster primaries.""" + """Cluster: Migrator creates indexes on cluster.""" model_registry.clear() class MigrTestJson(JsonModel): diff --git a/tests/test_connections.py b/tests/test_connections.py index eab8024..b5ed830 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -6,7 +6,8 @@ import pytest -from aredis_om.connections import get_redis_connection +from aredis_om import connections as connections_module +from aredis_om.connections import _strip_cluster_param, get_redis_connection class TestGetRedisConnection: @@ -42,6 +43,48 @@ def test_cluster_mode_from_param(self): except Exception: pass # expected since there's no cluster + def test_strip_cluster_param_preserves_other_query_params(self): + clean = _strip_cluster_param( + "redis://localhost:7001/0?decode_responses=True&cluster=true&protocol=3" + ) + + assert clean == "redis://localhost:7001/0?decode_responses=True&protocol=3" + + def test_strip_cluster_param_removes_case_insensitive_cluster_query_key(self): + clean = _strip_cluster_param( + "redis://localhost:7001/0?decode_responses=True&Cluster=True" + ) + + assert clean == "redis://localhost:7001/0?decode_responses=True" + + def test_strip_cluster_param_removes_duplicate_case_variants(self): + clean = _strip_cluster_param( + "redis://localhost:7001/0?cluster=true&Cluster=false&decode_responses=True" + ) + + assert clean == "redis://localhost:7001/0?decode_responses=True" + + def test_get_redis_connection_strips_cluster_query_before_from_url( + self, monkeypatch + ): + sentinel = object() + calls = {} + + def fake_from_url(url, **kwargs): + calls["url"] = url + calls["kwargs"] = kwargs + return sentinel + + monkeypatch.setattr(connections_module.redis.RedisCluster, "from_url", fake_from_url) + + conn = get_redis_connection( + url="redis://localhost:7001/0?decode_responses=True&Cluster=True&protocol=3" + ) + + assert conn is sentinel + assert calls["url"] == "redis://localhost:7001/0?decode_responses=True&protocol=3" + assert calls["kwargs"]["decode_responses"] is True + def test_no_url_no_env_uses_defaults(self, monkeypatch): monkeypatch.delenv("REDIS_OM_URL", raising=False) conn = get_redis_connection() diff --git a/tests/test_embedded_structures.py b/tests/test_embedded_structures.py new file mode 100644 index 0000000..d0ce81a --- /dev/null +++ b/tests/test_embedded_structures.py @@ -0,0 +1,1103 @@ +# type: ignore +""" +Tests for complex embedded structures in JsonModel/HashModel. + +Covers: +- Multiple EmbeddedJsonModel fields at the same level +- Deep (4-level) embedding chains +- Optional embedded models (None and set) +- List of EmbeddedJsonModel with indexed string fields +- Embedded model containing a List[str] with full_text_search +- OR / AND / NOT queries spanning multiple embedded levels +- Update operations on deeply-nested fields via __ notation +- get_many on models with embedded structures +- HashModel coexisting and interoperating with JsonModel (pk reference) +- Completely optional embedded models (all fields None) +- Embedded model with GEO coordinates inside a list +- Pipeline save with embedded models +- Empty / single-element list of embedded models +- Nested embedded re-querying after update +""" + +import abc +import collections +from typing import List, Optional + +import pytest +import pytest_asyncio + +from aredis_om import ( + Coordinates, + EmbeddedJsonModel, + Field, + HashModel, + JsonModel, + Migrator, + NotFoundError, + RedisModelError, +) +from aredis_om.model.model import model_registry +from tests._sync_redis import has_redis_json, has_redisearch + +from .conftest import py_test_mark_asyncio + +if not has_redis_json(): + pytestmark = pytest.mark.skip + + +# --------------------------------------------------------------------------- +# Shared fixture: base class + all embedded/top-level model definitions +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def em(key_prefix, redis): + """Return a namespace of all models used in this module.""" + model_registry.clear() + + # ── Base classes ────────────────────────────────────────────────────── + class BaseJson(JsonModel, abc.ABC): + class Meta: + global_key_prefix = key_prefix + database = redis + + class BaseHash(HashModel, abc.ABC): + class Meta: + global_key_prefix = key_prefix + database = redis + + # ── Level-4 (deepest) embedded ──────────────────────────────────────── + class ContactDetails(EmbeddedJsonModel): + """Deepest leaf: phone + email.""" + + phone: str = Field(index=True) + email: str = Field(index=True) + + # ── Level-3 embedded ───────────────────────────────────────────────── + class Office(EmbeddedJsonModel): + """Contains level-4 contact details.""" + + building: str = Field(index=True) + floor: int + contact: ContactDetails + + # ── Level-2 embedded ───────────────────────────────────────────────── + class Department(EmbeddedJsonModel): + """Contains level-3 office.""" + + name: str = Field(index=True) + budget: float + office: Office + + # ── Top-level JsonModel with deep embedding ─────────────────────────── + class Company(BaseJson): + """4-level deep: Company → Department → Office → ContactDetails.""" + + name: str = Field(index=True) + industry: str = Field(index=True) + department: Department + + # ── Multiple embedded models at the SAME level ──────────────────────── + class Skills(EmbeddedJsonModel): + primary: str = Field(index=True) + secondary: Optional[str] = Field(index=True, default=None) + + class Employment(EmbeddedJsonModel): + employer: str = Field(index=True) + role: str = Field(index=True) + years: int + + class Education(EmbeddedJsonModel): + institution: str = Field(index=True) + degree: str = Field(index=True) + graduated: int # year + + class Profile(BaseJson): + """Three EmbeddedJsonModel fields at the same level.""" + + username: str = Field(index=True) + bio: Optional[str] = Field(index=True, full_text_search=True, default="") + skills: Skills + employment: Employment + education: Education + + # ── Optional embedded model ─────────────────────────────────────────── + class Address(EmbeddedJsonModel): + city: str = Field(index=True) + country: str = Field(index=True) + + class Person(BaseJson): + name: str = Field(index=True) + address: Optional[Address] = None + + # ── List[EmbeddedJsonModel] with indexed string field ───────────────── + class Tag(EmbeddedJsonModel): + label: str = Field(index=True) + + class Article(BaseJson): + title: str = Field(index=True) + tags: List[Tag] + + # ── EmbeddedJsonModel containing List[str] with full_text_search ────── + class Section(EmbeddedJsonModel): + heading: str = Field(index=True) + keywords: List[str] = Field(index=True, full_text_search=True) + + class Document(BaseJson): + doc_title: str = Field(index=True) + section: Section + + # ── HashModel whose pk is referenced by a JsonModel ─────────────────── + class Tenant(BaseHash): + tenant_name: str = Field(index=True) + + class Subscription(BaseJson): + """Stores the HashModel pk as a plain string field.""" + + tenant_pk: str = Field(index=True) + plan: str = Field(index=True) + + # ── All-optional embedded model ─────────────────────────────────────── + class Metadata(EmbeddedJsonModel): + note: Optional[str] = Field(index=True, default=None) + score: Optional[float] = None + + class Widget(BaseJson): + widget_name: str = Field(index=True) + meta: Optional[Metadata] = None + + # ── EmbeddedJsonModel with Coordinates (GEO) ───────────────────────── + class Venue(EmbeddedJsonModel): + venue_name: str = Field(index=True) + location: Optional[Coordinates] = Field(index=True, default=None) + + class Event(BaseJson): + event_name: str = Field(index=True) + venue: Venue + + # ── List[EmbeddedJsonModel] where items have string + GEO fields ────── + class Stop(EmbeddedJsonModel): + city: str = Field(index=True) + + class Route(BaseJson): + route_name: str = Field(index=True) + stops: List[Stop] + + await Migrator().run() + + Models = collections.namedtuple( + "Models", + [ + "BaseJson", + "BaseHash", + "ContactDetails", + "Office", + "Department", + "Company", + "Skills", + "Employment", + "Education", + "Profile", + "Address", + "Person", + "Tag", + "Article", + "Section", + "Document", + "Tenant", + "Subscription", + "Metadata", + "Widget", + "Venue", + "Event", + "Stop", + "Route", + ], + ) + return Models( + BaseJson, + BaseHash, + ContactDetails, + Office, + Department, + Company, + Skills, + Employment, + Education, + Profile, + Address, + Person, + Tag, + Article, + Section, + Document, + Tenant, + Subscription, + Metadata, + Widget, + Venue, + Event, + Stop, + Route, + ) + + +# =========================================================================== +# 1. Deep (4-level) embedding – save / retrieve / query +# =========================================================================== + + +@py_test_mark_asyncio +async def test_deep_embed_save_and_retrieve(em): + contact = em.ContactDetails(phone="555-1234", email="ops@corp.com") + office = em.Office(building="HQ", floor=3, contact=contact) + dept = em.Department(name="Engineering", budget=1_000_000.0, office=office) + company = em.Company(name="Acme", industry="Technology", department=dept) + await company.save() + + fetched = await em.Company.get(company.pk) + assert fetched.name == "Acme" + assert fetched.department.name == "Engineering" + assert fetched.department.office.building == "HQ" + assert fetched.department.office.floor == 3 + assert fetched.department.office.contact.phone == "555-1234" + assert fetched.department.office.contact.email == "ops@corp.com" + + +@py_test_mark_asyncio +async def test_deep_embed_query_on_top_level_fields(em): + c1 = em.Company( + name="Alpha Inc", + industry="Finance", + department=em.Department( + name="Accounting", + budget=500_000.0, + office=em.Office( + building="Tower A", + floor=1, + contact=em.ContactDetails(phone="111", email="a@a.com"), + ), + ), + ) + c2 = em.Company( + name="Beta Corp", + industry="Technology", + department=em.Department( + name="R&D", + budget=2_000_000.0, + office=em.Office( + building="Lab", + floor=5, + contact=em.ContactDetails(phone="222", email="b@b.com"), + ), + ), + ) + await c1.save() + await c2.save() + + results = await em.Company.find(em.Company.industry == "Technology").all() + pks = {r.pk for r in results} + assert c2.pk in pks + assert c1.pk not in pks + + +@py_test_mark_asyncio +async def test_deep_embed_query_on_level2_field(em): + c1 = em.Company( + name="Zeta", + industry="Health", + department=em.Department( + name="Billing", + budget=300_000.0, + office=em.Office( + building="MedTower", + floor=2, + contact=em.ContactDetails(phone="999", email="z@z.com"), + ), + ), + ) + await c1.save() + + results = await em.Company.find(em.Company.department.name == "Billing").all() + pks = {r.pk for r in results} + assert c1.pk in pks + + +@py_test_mark_asyncio +async def test_deep_embed_query_on_level3_field(em): + c1 = em.Company( + name="Sigma", + industry="Retail", + department=em.Department( + name="Logistics", + budget=750_000.0, + office=em.Office( + building="Warehouse", + floor=1, + contact=em.ContactDetails(phone="777", email="s@s.com"), + ), + ), + ) + await c1.save() + + results = await em.Company.find( + em.Company.department.office.building == "Warehouse" + ).all() + pks = {r.pk for r in results} + assert c1.pk in pks + + +@py_test_mark_asyncio +async def test_deep_embed_query_on_level4_field(em): + c1 = em.Company( + name="Tau", + industry="Media", + department=em.Department( + name="Editorial", + budget=200_000.0, + office=em.Office( + building="Press House", + floor=4, + contact=em.ContactDetails(phone="333-UNIQUE", email="tau@tau.com"), + ), + ), + ) + await c1.save() + + results = await em.Company.find( + em.Company.department.office.contact.phone == "333-UNIQUE" + ).all() + pks = {r.pk for r in results} + assert c1.pk in pks + + +# =========================================================================== +# 2. Multiple EmbeddedJsonModel fields at the same level +# =========================================================================== + + +@py_test_mark_asyncio +async def test_multiple_embedded_same_level_save_retrieve(em): + skills = em.Skills(primary="Python", secondary="Rust") + employment = em.Employment(employer="OpenAI", role="Engineer", years=3) + education = em.Education(institution="MIT", degree="BSc", graduated=2020) + + profile = em.Profile( + username="dev42", + bio="Backend developer passionate about distributed systems", + skills=skills, + employment=employment, + education=education, + ) + await profile.save() + + fetched = await em.Profile.get(profile.pk) + assert fetched.skills.primary == "Python" + assert fetched.skills.secondary == "Rust" + assert fetched.employment.employer == "OpenAI" + assert fetched.employment.role == "Engineer" + assert fetched.education.institution == "MIT" + assert fetched.education.degree == "BSc" + + +@py_test_mark_asyncio +async def test_multiple_embedded_same_level_query_skills(em): + p1 = em.Profile( + username="rustacean", + skills=em.Skills(primary="Rust"), + employment=em.Employment(employer="Mozilla", role="SWE", years=2), + education=em.Education(institution="Stanford", degree="MSc", graduated=2019), + ) + p2 = em.Profile( + username="pythonista", + skills=em.Skills(primary="Python"), + employment=em.Employment(employer="Google", role="SWE", years=5), + education=em.Education(institution="Caltech", degree="PhD", graduated=2021), + ) + await p1.save() + await p2.save() + + results = await em.Profile.find(em.Profile.skills.primary == "Rust").all() + pks = {r.pk for r in results} + assert p1.pk in pks + assert p2.pk not in pks + + +@py_test_mark_asyncio +async def test_multiple_embedded_same_level_query_employment(em): + p1 = em.Profile( + username="googler", + skills=em.Skills(primary="Go"), + employment=em.Employment(employer="Google", role="SRE", years=4), + education=em.Education(institution="CMU", degree="BSc", graduated=2018), + ) + p2 = em.Profile( + username="amazonian", + skills=em.Skills(primary="Java"), + employment=em.Employment(employer="Amazon", role="DevOps", years=6), + education=em.Education(institution="UW", degree="BSc", graduated=2016), + ) + await p1.save() + await p2.save() + + results = await em.Profile.find(em.Profile.employment.employer == "Amazon").all() + pks = {r.pk for r in results} + assert p2.pk in pks + assert p1.pk not in pks + + +@py_test_mark_asyncio +async def test_multiple_embedded_same_level_query_education(em): + p1 = em.Profile( + username="oxbridge", + skills=em.Skills(primary="C++"), + employment=em.Employment(employer="ARM", role="HW", years=7), + education=em.Education(institution="Oxford", degree="MEng", graduated=2017), + ) + await p1.save() + + results = await em.Profile.find(em.Profile.education.institution == "Oxford").all() + pks = {r.pk for r in results} + assert p1.pk in pks + + +@py_test_mark_asyncio +async def test_or_query_spanning_two_embedded_models(em): + """OR across two different EmbeddedJsonModel fields must use correct prefixes.""" + p1 = em.Profile( + username="player1", + skills=em.Skills(primary="Kotlin"), + employment=em.Employment(employer="JetBrains", role="Dev", years=1), + education=em.Education(institution="HSE", degree="BSc", graduated=2023), + ) + p2 = em.Profile( + username="player2", + skills=em.Skills(primary="Swift"), + employment=em.Employment(employer="Apple", role="iOS", years=2), + education=em.Education(institution="UIUC", degree="BSc", graduated=2022), + ) + await p1.save() + await p2.save() + + # Query: primary skill is Kotlin OR employer is Apple + results = await em.Profile.find( + (em.Profile.skills.primary == "Kotlin") + | (em.Profile.employment.employer == "Apple") + ).all() + pks = {r.pk for r in results} + assert p1.pk in pks + assert p2.pk in pks + + +@py_test_mark_asyncio +async def test_and_query_spanning_two_embedded_models(em): + p1 = em.Profile( + username="combo1", + skills=em.Skills(primary="Scala"), + employment=em.Employment(employer="Databricks", role="Eng", years=3), + education=em.Education(institution="Berkeley", degree="MSc", graduated=2020), + ) + p2 = em.Profile( + username="combo2", + skills=em.Skills(primary="Scala"), + employment=em.Employment(employer="Snowflake", role="Eng", years=2), + education=em.Education(institution="Yale", degree="BSc", graduated=2021), + ) + await p1.save() + await p2.save() + + # Both have Scala as primary but only p1 is at Databricks + results = await em.Profile.find( + (em.Profile.skills.primary == "Scala") + & (em.Profile.employment.employer == "Databricks") + ).all() + pks = {r.pk for r in results} + assert p1.pk in pks + assert p2.pk not in pks + + +@py_test_mark_asyncio +async def test_and_query_spanning_parent_and_embedded(em): + p1 = em.Profile( + username="special_user", + skills=em.Skills(primary="TypeScript"), + employment=em.Employment(employer="Vercel", role="FE", years=2), + education=em.Education(institution="NYU", degree="BSc", graduated=2022), + ) + p2 = em.Profile( + username="other_user", + skills=em.Skills(primary="TypeScript"), + employment=em.Employment(employer="Netlify", role="FE", years=1), + education=em.Education(institution="NYU", degree="BSc", graduated=2022), + ) + await p1.save() + await p2.save() + + # username == "special_user" AND skills.primary == "TypeScript" + results = await em.Profile.find( + (em.Profile.username == "special_user") + & (em.Profile.skills.primary == "TypeScript") + ).all() + pks = {r.pk for r in results} + assert p1.pk in pks + assert p2.pk not in pks + + +@py_test_mark_asyncio +async def test_not_query_on_embedded_field(em): + p1 = em.Profile( + username="not_apple", + skills=em.Skills(primary="Elixir"), + employment=em.Employment(employer="Discord", role="BE", years=4), + education=em.Education(institution="RIT", degree="BSc", graduated=2019), + ) + p2 = em.Profile( + username="apple_employee", + skills=em.Skills(primary="Swift"), + employment=em.Employment(employer="Apple", role="iOS", years=3), + education=em.Education(institution="UW", degree="BSc", graduated=2018), + ) + await p1.save() + await p2.save() + + results = await em.Profile.find(~(em.Profile.employment.employer == "Apple")).all() + pks = {r.pk for r in results} + assert p1.pk in pks + assert p2.pk not in pks + + +# =========================================================================== +# 3. Optional embedded model (None vs. set) +# =========================================================================== + + +@py_test_mark_asyncio +async def test_optional_embedded_model_none(em): + person = em.Person(name="NoAddress") + await person.save() + + fetched = await em.Person.get(person.pk) + assert fetched.name == "NoAddress" + assert fetched.address is None + + +@py_test_mark_asyncio +async def test_optional_embedded_model_set(em): + person = em.Person( + name="WithAddress", + address=em.Address(city="Berlin", country="Germany"), + ) + await person.save() + + fetched = await em.Person.get(person.pk) + assert fetched.address is not None + assert fetched.address.city == "Berlin" + assert fetched.address.country == "Germany" + + +@py_test_mark_asyncio +async def test_optional_embedded_model_query_on_city(em): + p1 = em.Person(name="Berliner", address=em.Address(city="Berlin", country="DE")) + p2 = em.Person(name="Parisian", address=em.Address(city="Paris", country="FR")) + p3 = em.Person(name="NoCity") + await p1.save() + await p2.save() + await p3.save() + + results = await em.Person.find(em.Person.address.city == "Berlin").all() + pks = {r.pk for r in results} + assert p1.pk in pks + assert p2.pk not in pks + + +@py_test_mark_asyncio +async def test_optional_embedded_updated_from_none_to_set(em): + """Set an optional embedded model that was initially None.""" + person = em.Person(name="Changeable") + await person.save() + assert person.address is None + + await person.update(address=em.Address(city="Rome", country="IT")) + fetched = await em.Person.get(person.pk) + assert fetched.address is not None + assert fetched.address.city == "Rome" + + +# =========================================================================== +# 4. List of EmbeddedJsonModel +# =========================================================================== + + +@py_test_mark_asyncio +async def test_list_of_embedded_save_retrieve(em): + article = em.Article( + title="Redis OM Deep Dive", + tags=[em.Tag(label="redis"), em.Tag(label="python"), em.Tag(label="orm")], + ) + await article.save() + + fetched = await em.Article.get(article.pk) + assert fetched.title == "Redis OM Deep Dive" + assert len(fetched.tags) == 3 + labels = {t.label for t in fetched.tags} + assert labels == {"redis", "python", "orm"} + + +@py_test_mark_asyncio +async def test_list_of_embedded_query_on_label(em): + a1 = em.Article( + title="Redis Basics", + tags=[em.Tag(label="redis"), em.Tag(label="tutorial")], + ) + a2 = em.Article( + title="Python Tips", + tags=[em.Tag(label="python"), em.Tag(label="tips")], + ) + await a1.save() + await a2.save() + + results = await em.Article.find(em.Article.tags.label == "redis").all() + pks = {r.pk for r in results} + assert a1.pk in pks + assert a2.pk not in pks + + +@py_test_mark_asyncio +async def test_list_of_embedded_empty_list(em): + """An empty list of embedded models should save and retrieve cleanly.""" + article = em.Article(title="No Tags", tags=[]) + await article.save() + + fetched = await em.Article.get(article.pk) + assert fetched.title == "No Tags" + assert fetched.tags == [] + + +@py_test_mark_asyncio +async def test_list_of_embedded_single_item(em): + article = em.Article(title="One Tag", tags=[em.Tag(label="solo")]) + await article.save() + + fetched = await em.Article.get(article.pk) + assert len(fetched.tags) == 1 + assert fetched.tags[0].label == "solo" + + +# =========================================================================== +# 5. Embedded model containing List[str] with full_text_search +# =========================================================================== + + +@py_test_mark_asyncio +async def test_embedded_with_list_str_fts_save_retrieve(em): + doc = em.Document( + doc_title="Advanced Redis", + section=em.Section( + heading="Indexing", + keywords=["search", "indexing", "performance"], + ), + ) + await doc.save() + + fetched = await em.Document.get(doc.pk) + assert fetched.section.heading == "Indexing" + assert set(fetched.section.keywords) == {"search", "indexing", "performance"} + + +@py_test_mark_asyncio +async def test_embedded_with_list_str_fts_tag_membership(em): + doc = em.Document( + doc_title="Tagged Doc", + section=em.Section(heading="Intro", keywords=["hello", "world"]), + ) + await doc.save() + + results = await em.Document.find(em.Document.section.keywords << ["hello"]).all() + pks = {r.pk for r in results} + assert doc.pk in pks + + +# =========================================================================== +# 6. HashModel coexisting with JsonModel; pk reference pattern +# =========================================================================== + + +@pytest.mark.skipif(not has_redisearch(), reason="requires RediSearch") +@py_test_mark_asyncio +async def test_hash_and_json_coexist_in_same_key_prefix(em): + """HashModel and JsonModel should operate independently under the same prefix.""" + tenant = em.Tenant(tenant_name="AcmeCorp") + await tenant.save() + + sub = em.Subscription(tenant_pk=str(tenant.pk), plan="enterprise") + await sub.save() + + fetched_tenant = await em.Tenant.get(tenant.pk) + fetched_sub = await em.Subscription.get(sub.pk) + + assert fetched_tenant.tenant_name == "AcmeCorp" + assert fetched_sub.plan == "enterprise" + assert fetched_sub.tenant_pk == str(tenant.pk) + + +@pytest.mark.skipif(not has_redisearch(), reason="requires RediSearch") +@py_test_mark_asyncio +async def test_json_find_by_hash_model_pk_reference(em): + """Query a JsonModel by the pk of a referenced HashModel stored as a string.""" + tenant = em.Tenant(tenant_name="Foo Inc") + await tenant.save() + + sub = em.Subscription(tenant_pk=str(tenant.pk), plan="basic") + await sub.save() + + results = await em.Subscription.find( + em.Subscription.tenant_pk == str(tenant.pk) + ).all() + assert len(results) == 1 + assert results[0].plan == "basic" + + +# =========================================================================== +# 7. All-optional embedded model +# =========================================================================== + + +@py_test_mark_asyncio +async def test_all_optional_embedded_none(em): + widget = em.Widget(widget_name="bare") + await widget.save() + + fetched = await em.Widget.get(widget.pk) + assert fetched.widget_name == "bare" + assert fetched.meta is None + + +@py_test_mark_asyncio +async def test_all_optional_embedded_partial(em): + widget = em.Widget( + widget_name="partial", meta=em.Metadata(note="important", score=None) + ) + await widget.save() + + fetched = await em.Widget.get(widget.pk) + assert fetched.meta is not None + assert fetched.meta.note == "important" + assert fetched.meta.score is None + + +@py_test_mark_asyncio +async def test_all_optional_embedded_full(em): + widget = em.Widget(widget_name="full", meta=em.Metadata(note="hello", score=9.5)) + await widget.save() + + fetched = await em.Widget.get(widget.pk) + assert fetched.meta.note == "hello" + assert abs(fetched.meta.score - 9.5) < 1e-6 + + +# =========================================================================== +# 8. Embedded model with GEO (Coordinates) +# =========================================================================== + + +@py_test_mark_asyncio +async def test_embedded_with_geo_save_retrieve(em): + event = em.Event( + event_name="Tech Summit", + venue=em.Venue( + venue_name="Convention Center", + location=Coordinates(latitude=37.7749, longitude=-122.4194), + ), + ) + await event.save() + + fetched = await em.Event.get(event.pk) + assert fetched.event_name == "Tech Summit" + assert fetched.venue.venue_name == "Convention Center" + assert fetched.venue.location is not None + + +@py_test_mark_asyncio +async def test_embedded_with_geo_none_location(em): + event = em.Event( + event_name="Virtual Event", + venue=em.Venue(venue_name="Online", location=None), + ) + await event.save() + + fetched = await em.Event.get(event.pk) + assert fetched.venue.location is None + + +# =========================================================================== +# 9. List[EmbeddedJsonModel] with city field +# =========================================================================== + + +@py_test_mark_asyncio +async def test_route_list_stops_save_retrieve(em): + route = em.Route( + route_name="West Coast", + stops=[em.Stop(city="LA"), em.Stop(city="SF"), em.Stop(city="Seattle")], + ) + await route.save() + + fetched = await em.Route.get(route.pk) + assert fetched.route_name == "West Coast" + assert len(fetched.stops) == 3 + cities = {s.city for s in fetched.stops} + assert cities == {"LA", "SF", "Seattle"} + + +@py_test_mark_asyncio +async def test_route_list_query_on_stop_city(em): + r1 = em.Route( + route_name="US Route", + stops=[em.Stop(city="NYC"), em.Stop(city="Boston")], + ) + r2 = em.Route( + route_name="EU Route", + stops=[em.Stop(city="Paris"), em.Stop(city="Berlin")], + ) + await r1.save() + await r2.save() + + results = await em.Route.find(em.Route.stops.city == "Paris").all() + pks = {r.pk for r in results} + assert r2.pk in pks + assert r1.pk not in pks + + +# =========================================================================== +# 10. Update operations on embedded fields via __ path notation +# =========================================================================== + + +@py_test_mark_asyncio +async def test_update_embedded_field_via_double_underscore(em): + profile = em.Profile( + username="updatable", + skills=em.Skills(primary="Java"), + employment=em.Employment(employer="Oracle", role="Dev", years=10), + education=em.Education(institution="IIT", degree="BTech", graduated=2013), + ) + await profile.save() + + await profile.update(employment__employer="SAP") + fetched = await em.Profile.get(profile.pk) + assert fetched.employment.employer == "SAP" + + +@py_test_mark_asyncio +async def test_update_deep_embedded_field(em): + contact = em.ContactDetails(phone="000", email="before@x.com") + office = em.Office(building="Old", floor=1, contact=contact) + dept = em.Department(name="IT", budget=100_000.0, office=office) + company = em.Company(name="Updatable Corp", industry="Finance", department=dept) + await company.save() + + # Update the nested contact email + await company.update(department__office__contact__email="after@x.com") + fetched = await em.Company.get(company.pk) + assert fetched.department.office.contact.email == "after@x.com" + + +@py_test_mark_asyncio +async def test_update_optional_embedded_to_new_value(em): + person = em.Person(name="UpdateMe") + await person.save() + assert person.address is None + + await person.update(address=em.Address(city="Madrid", country="ES")) + fetched = await em.Person.get(person.pk) + assert fetched.address.city == "Madrid" + + await person.update(address__city="Barcelona") + fetched2 = await em.Person.get(person.pk) + assert fetched2.address.city == "Barcelona" + assert fetched2.address.country == "ES" + + +# =========================================================================== +# 11. get_many with embedded structures +# =========================================================================== + + +@py_test_mark_asyncio +async def test_get_many_with_embedded(em): + profiles = [] + for i in range(5): + p = em.Profile( + username=f"bulk_user_{i}", + skills=em.Skills(primary=f"Lang{i}"), + employment=em.Employment(employer=f"Employer{i}", role="Dev", years=i), + education=em.Education( + institution=f"Uni{i}", degree="BSc", graduated=2020 + i + ), + ) + await p.save() + profiles.append(p) + + pks = [p.pk for p in profiles] + results = await em.Profile.get_many(pks) + assert len(results) == 5 + usernames = {r.username for r in results} + assert usernames == {f"bulk_user_{i}" for i in range(5)} + + +# =========================================================================== +# 12. Pipeline save with embedded models +# =========================================================================== + + +@py_test_mark_asyncio +async def test_pipeline_save_with_embedded(em): + pipeline = em.Profile.Meta.database.pipeline() + + p1 = em.Profile( + username="pipe_user_1", + skills=em.Skills(primary="Ruby"), + employment=em.Employment(employer="Shopify", role="BE", years=4), + education=em.Education(institution="Waterloo", degree="BSc", graduated=2019), + ) + p2 = em.Profile( + username="pipe_user_2", + skills=em.Skills(primary="PHP"), + employment=em.Employment(employer="WordPress", role="FE", years=3), + education=em.Education(institution="Toronto", degree="BSc", graduated=2020), + ) + await p1.save(pipeline=pipeline) + await p2.save(pipeline=pipeline) + await pipeline.execute() + + fetched1 = await em.Profile.get(p1.pk) + fetched2 = await em.Profile.get(p2.pk) + assert fetched1.username == "pipe_user_1" + assert fetched2.username == "pipe_user_2" + assert fetched1.skills.primary == "Ruby" + assert fetched2.skills.primary == "PHP" + + +# =========================================================================== +# 13. Schema correctness assertions +# =========================================================================== + + +def test_multiple_embedded_schema_contains_all_prefixed_fields(em): + schema = em.Profile.redisearch_schema() + # Each embedded model should have its own prefixed index entries + assert "$.skills.primary" in schema + assert "$.employment.employer" in schema + assert "$.employment.role" in schema + assert "$.education.institution" in schema + assert "$.education.degree" in schema + + +def test_deep_embed_schema_contains_level4_paths(em): + schema = em.Company.redisearch_schema() + assert "$.department.name" in schema + assert "$.department.office.building" in schema + assert "$.department.office.contact.phone" in schema + assert "$.department.office.contact.email" in schema + + +# =========================================================================== +# 14. Edge case: saving the same model instance twice (idempotent update) +# =========================================================================== + + +@py_test_mark_asyncio +async def test_save_twice_preserves_final_state(em): + person = em.Person(name="Dup", address=em.Address(city="Oslo", country="NO")) + await person.save() + person.name = "DupUpdated" + person.address.city = "Bergen" + await person.save() + + fetched = await em.Person.get(person.pk) + assert fetched.name == "DupUpdated" + assert fetched.address.city == "Bergen" + + +# =========================================================================== +# 15. Edge case: two separate models with identically-named embedded fields +# should not cross-contaminate queries (regression for prefix-sharing bug) +# =========================================================================== + + +@py_test_mark_asyncio +async def test_or_query_same_field_name_two_embedded_models(em): + """ + Profile.skills.primary and Profile.employment.role are different fields + but both named 'primary'/'role' in their respective embedded models. + An OR query should produce the correct per-model prefixes. + """ + p1 = em.Profile( + username="prefix_test_1", + skills=em.Skills(primary="Go"), + employment=em.Employment(employer="HashiCorp", role="SRE", years=5), + education=em.Education(institution="McGill", degree="BSc", graduated=2018), + ) + p2 = em.Profile( + username="prefix_test_2", + skills=em.Skills(primary="Rust"), + employment=em.Employment(employer="Cloudflare", role="Network", years=3), + education=em.Education(institution="McGill", degree="BSc", graduated=2019), + ) + await p1.save() + await p2.save() + + # OR on two unrelated embedded fields + results = await em.Profile.find( + (em.Profile.skills.primary == "Go") + | (em.Profile.employment.employer == "Cloudflare") + ).all() + pks = {r.pk for r in results} + assert p1.pk in pks + assert p2.pk in pks + + +# =========================================================================== +# 16. HashModel basic CRUD (sanity check alongside JsonModel tests) +# =========================================================================== + + +@pytest.mark.skipif(not has_redisearch(), reason="requires RediSearch") +@py_test_mark_asyncio +async def test_hash_model_basic_crud(em): + t1 = em.Tenant(tenant_name="Tenant A") + await t1.save() + + fetched = await em.Tenant.get(t1.pk) + assert fetched.tenant_name == "Tenant A" + + t1.tenant_name = "Tenant A Updated" + await t1.save() + + updated = await em.Tenant.get(t1.pk) + assert updated.tenant_name == "Tenant A Updated" + + await em.Tenant.delete(t1.pk) + with pytest.raises(NotFoundError): + await em.Tenant.get(t1.pk) + + +@pytest.mark.skipif(not has_redisearch(), reason="requires RediSearch") +@py_test_mark_asyncio +async def test_hash_model_query_alongside_json_model(em): + """Both models share the same Redis instance; queries must be isolated.""" + t1 = em.Tenant(tenant_name="QueryTenantX") + await t1.save() + + sub = em.Subscription(tenant_pk=str(t1.pk), plan="pro") + await sub.save() + + # Searching for the HashModel by its own field + hash_results = await em.Tenant.find(em.Tenant.tenant_name == "QueryTenantX").all() + assert any(r.pk == t1.pk for r in hash_results) + + # Searching for the JsonModel by the referenced pk string + json_results = await em.Subscription.find( + em.Subscription.tenant_pk == str(t1.pk) + ).all() + assert any(r.pk == sub.pk for r in json_results) diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 01c3736..a517f4b 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -24,6 +24,7 @@ QueryNotSupportedError, RedisModelError, ) +from aredis_om.model.model import SINGLE_VALUE_TAG_FIELD_SEPARATOR from tests._compat import EmailStr, PositiveInt, ValidationError from tests._sync_redis import has_redis_json @@ -845,9 +846,9 @@ class SortableTarotWitch(m.BaseJsonModel): with pytest.raises(RedisModelError): class SortableFullTextSearchAlchemicalWitch(m.BaseJsonModel): - # We don't support indexing a list of strings for full-text search - # queries. Support for this feature is not planned. - potions: List[str] = Field(index=True, full_text_search=True) + # Sorting multi-value fields is not supported, including when the + # same field is also indexed for full-text search. + potions: List[str] = Field(index=True, full_text_search=True, sortable=True) with pytest.raises(RedisModelError): @@ -886,6 +887,35 @@ class TarotWitch(m.BaseJsonModel): assert actual == [witch] +@py_test_mark_asyncio +async def test_string_list_field_allows_full_text_search(m): + class AlchemicalWitch(m.BaseJsonModel): + potions: List[str] = Field(index=True, full_text_search=True) + + assert ( + f"$.potions[*] AS potions TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} " + "$.potions[*] AS potions_fts TEXT" in AlchemicalWitch.redisearch_schema() + ) + + await Migrator().run() + + old_pks = [pk async for pk in await AlchemicalWitch.all_pks()] + for pk in old_pks: + await AlchemicalWitch.delete(pk) + + first = AlchemicalWitch(potions=["healing", "mana"]) + second = AlchemicalWitch(potions=["invisibility", "speed"]) + await first.save() + await second.save() + + assert await AlchemicalWitch.find(AlchemicalWitch.potions << ["mana"]).all() == [ + first + ] + assert await AlchemicalWitch.find( + AlchemicalWitch.potions % "invisibility" + ).all() == [second] + + @py_test_mark_asyncio async def test_allows_dataclasses(m): @dataclasses.dataclass diff --git a/tests/test_regressions.py b/tests/test_regressions.py index a4e1e72..e7b23a2 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -13,6 +13,7 @@ from aredis_om.connections import get_redis_connection from aredis_om.model import model as model_module from aredis_om.model.cli import migrate as migrate_cli_module +from aredis_om.model.migrations import migrator as migrator_module from aredis_om.model.model import ( Expression, convert_datetime_to_timestamp, @@ -207,6 +208,45 @@ def safe_run(coro): assert state == {"detected": 1, "ran": 1} +@py_test_mark_asyncio +async def test_cluster_create_index_targets_one_random_node(monkeypatch): + calls = [] + writes = [] + + class MissingIndexStub: + async def info(self): + raise migrator_module.redis.ResponseError("missing index") + + class FakeClusterConn: + def ft(self, _index_name): + return MissingIndexStub() + + async def execute_command(self, *args, **kwargs): + calls.append((args, kwargs)) + return "OK" + + async def set(self, key, value): + writes.append((key, value)) + return True + + conn = FakeClusterConn() + + await migrator_module._create_index_cluster( + conn, + "test-index", + "ON HASH PREFIX 1 test: SCHEMA name TAG", + "schema-hash", + ) + + assert calls == [ + ( + ("ft.create", "test-index", "ON", "HASH", "PREFIX", "1", "test:", "SCHEMA", "name", "TAG"), + {"target_nodes": migrator_module.redis.RedisCluster.RANDOM}, + ) + ] + assert writes == [("test-index:hash", "schema-hash")] + + @pytest.mark.skipif(not HAS_REDISEARCH, reason="requires RediSearch") @py_test_mark_asyncio async def test_aggregate_ct_handles_decode_response_strings(key_prefix, redis):