From 44e67460f6d939f9f1389fc83cbdafc94b08dd32 Mon Sep 17 00:00:00 2001 From: James Geiger Date: Tue, 18 Mar 2025 23:34:11 -0500 Subject: [PATCH 1/5] Implements enum datatype in OGM --- gqlalchemy/__init__.py | 1 + gqlalchemy/models.py | 49 ++++++++++++++++++++++++++- gqlalchemy/vendors/database_client.py | 28 +++++++++++++-- gqlalchemy/vendors/memgraph.py | 29 ++++++++++++++++ 4 files changed, 104 insertions(+), 3 deletions(-) diff --git a/gqlalchemy/__init__.py b/gqlalchemy/__init__.py index bccf728d..be797d06 100644 --- a/gqlalchemy/__init__.py +++ b/gqlalchemy/__init__.py @@ -20,6 +20,7 @@ MemgraphConstraintExists, MemgraphConstraintUnique, MemgraphIndex, + MemgraphEnum, MemgraphKafkaStream, MemgraphPulsarStream, MemgraphTrigger, diff --git a/gqlalchemy/models.py b/gqlalchemy/models.py index c60d0dc5..73911c7f 100644 --- a/gqlalchemy/models.py +++ b/gqlalchemy/models.py @@ -17,7 +17,9 @@ from dataclasses import dataclass from datetime import datetime, date, time, timedelta from enum import Enum +import json from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from enum import Enum, EnumMeta from pydantic.v1 import BaseModel, Extra, Field, PrivateAttr # noqa F401 @@ -57,7 +59,30 @@ def _format_timedelta(duration: timedelta) -> str: return f"P{days}DT{hours}H{minutes}M{remainder_sec}S" +class GraphEnum(ABC): + def __init__(self, enum): + if not isinstance(enum, (Enum, EnumMeta)): + raise TypeError() + + self.enum = enum if isinstance(enum, Enum) else None + self.cls = enum.__class__ if isinstance(enum, Enum) else enum + + @property + def name(self): + return self.cls.__name__ + + @abstractmethod + def _to_cypher(self): + pass + +class MemgraphEnum(GraphEnum): + def _to_cypher(self): + return f"{{ {', '.join(self.cls._member_names_)} }}" + + def __repr__(self): + return f"" if self.enum is None else f'{self.name}::{self.enum.name}' + class TriggerEventType: """An enum representing types of trigger events.""" @@ -307,6 +332,17 @@ class GraphObject(BaseModel): class Config: extra = Extra.allow + def __init__(self, **data): + for field in self.__class__.__fields__: + attrs = self.__class__.__fields__[field] + cls = self.__fields__[field].type_ + if issubclass(cls, Enum) and not attrs.get("enum", False): + value = data.get(field) + if isinstance(value, dict): + member = value.get("__value").split('::')[1] + data[field] = cls[member].value + super().__init__(**data) + def __init_subclass__(cls, type=None, label=None, labels=None, index=None, db=None): """Stores the subclass by type if type is specified, or by class name when instantiating a subclass. @@ -371,6 +407,8 @@ def escape_value( return repr(value) elif value_type == float: return repr(value) + elif isinstance(value, Enum): + return repr(MemgraphEnum(value)) elif isinstance(value, str): return repr(value) if value.isprintable() else rf"'{value}'" elif isinstance(value, list): @@ -445,7 +483,11 @@ def _get_cypher_set_properties(self, variable_name: str) -> str: cypher_set_properties = [] for field in self.__fields__: attributes = self.__fields__[field].field_info.extra - value = getattr(self, field) + cls = self.__fields__[field].type_ + if issubclass(cls, Enum) and not attributes.get("enum", False): + value = getattr(self, field).value + else: + value = getattr(self, field) if value is not None and not attributes.get("on_disk", False): cypher_set_properties.append(f" SET {variable_name}.{field} = {self.escape_value(value)}") @@ -521,12 +563,17 @@ def get_base_labels() -> Set[str]: for field in cls.__fields__: attrs = cls.__fields__[field].field_info.extra field_type = cls.__fields__[field].type_.__name__ + field_cls = cls.__fields__[field].type_ label = attrs.get("label", cls.label) skip_constraints = False if db is None: db = attrs.get("db") + # TODO: Implement enum creation and value add + if issubclass(field_cls, Enum) and attrs.get("enum", False): + pass + for constraint in FieldAttrsConstants.list(): if constraint in attrs and db is None: base = field_in_superclass(field, constraint) diff --git a/gqlalchemy/vendors/database_client.py b/gqlalchemy/vendors/database_client.py index df907f67..58f40528 100644 --- a/gqlalchemy/vendors/database_client.py +++ b/gqlalchemy/vendors/database_client.py @@ -20,11 +20,11 @@ from gqlalchemy.models import ( Constraint, Index, + GraphEnum, Node, - Relationship, + Relationship ) - class DatabaseClient(ABC): def __init__( self, @@ -127,6 +127,30 @@ def ensure_constraints( self.drop_constraint(obsolete_constraints) for missing_constraint in new_constraints.difference(old_constraints): self.create_constraint(missing_constraint) + + @abstractmethod + def create_enum(self, enum: GraphEnum) -> None: + pass + + @abstractmethod + def get_enums(self) -> List[GraphEnum]: + """Returns a list of all enums defined in the database.""" + pass + + @abstractmethod + def ensure_enums(self, indexes: List[GraphEnum]) -> None: + """Ensures that database enums match input enums.""" + pass + + @abstractmethod + def drop_enum(self, enum: GraphEnum) -> None: + """Drops a single enum in the database.""" + pass + + @abstractmethod + def drop_enums(self) -> None: + """Drops all enums in the database""" + pass def drop_database(self): """Drops database by removing all nodes and edges.""" diff --git a/gqlalchemy/vendors/memgraph.py b/gqlalchemy/vendors/memgraph.py index fb5fb7a9..15b218e4 100644 --- a/gqlalchemy/vendors/memgraph.py +++ b/gqlalchemy/vendors/memgraph.py @@ -16,6 +16,7 @@ import os import sqlite3 from typing import List, Optional, Union +import warnings from gqlalchemy.connection import Connection, MemgraphConnection from gqlalchemy.disk_storage import OnDiskPropertyDatabase @@ -24,6 +25,7 @@ GQLAlchemyFileNotFoundError, GQLAlchemyOnDiskPropertyDatabaseNotDefinedError, GQLAlchemyUniquenessConstraintError, + GQLAlchemyWarning, ) from gqlalchemy.models import ( MemgraphConstraintExists, @@ -31,6 +33,7 @@ MemgraphIndex, MemgraphStream, MemgraphTrigger, + MemgraphEnum, Node, Relationship, ) @@ -166,6 +169,32 @@ def get_constraints( ) ) return constraints + + def create_enum(self, graph_enum: MemgraphEnum) -> None: + query = f"CREATE ENUM {graph_enum.name} VALUES {graph_enum._to_cypher()};" + self.execute(query) + + def get_enums(self) -> List[MemgraphEnum]: + """Returns a list of all enums defined in the database.""" + enums: List[MemgraphEnum] = [] + for result in self.execute_and_fetch("SHOW ENUMS;"): + enums.append(MemgraphEnum(Enum(result['Enum Name'], result['Enum Values']))) + return enums + + def ensure_enums(self, graph_enums: List[MemgraphEnum]) -> None: + """Ensures that database enums match input enums.""" + current_enums = set(self.get_enums()) + new_enums = set(graph_enums) + for obsolete_enum in current_enums.difference(new_enums): + warnings.warn(GQLAlchemyWarning(f"DROP ENUM not yet implemented. Enum {obsolete_enum.__name__} is persisted in the database.")) + for missing_enum in new_enums.difference(current_enums): + self.create_enum(missing_enum) + + def drop_enum(self, graph_enum: Enum): + raise GQLAlchemyError(f"DROP ENUM not yet implemented. Enum {graph_enum.__name__} is persisted in the database.") + + def drop_enums(self, graph_enums: List[Enum]): + raise GQLAlchemyError(f"DROP ENUM not yet implemented. Enums {', '.join(graph_enums)} are persisted in the database.") def get_exists_constraints( self, From 71eb4f5fef8363e4f3fabb153a823944f13df265 Mon Sep 17 00:00:00 2001 From: James Geiger Date: Tue, 25 Mar 2025 22:48:15 -0500 Subject: [PATCH 2/5] Defines new abstract methods and corresponding overrides for enums on db vendors. --- gqlalchemy/vendors/database_client.py | 4 ++-- gqlalchemy/vendors/memgraph.py | 22 ++++++++++------------ gqlalchemy/vendors/neo4j.py | 19 +++++++++++++++++++ 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/gqlalchemy/vendors/database_client.py b/gqlalchemy/vendors/database_client.py index 58f40528..7e8f9f76 100644 --- a/gqlalchemy/vendors/database_client.py +++ b/gqlalchemy/vendors/database_client.py @@ -138,8 +138,8 @@ def get_enums(self) -> List[GraphEnum]: pass @abstractmethod - def ensure_enums(self, indexes: List[GraphEnum]) -> None: - """Ensures that database enums match input enums.""" + def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None: + """Ensures that database enum matches input enum.""" pass @abstractmethod diff --git a/gqlalchemy/vendors/memgraph.py b/gqlalchemy/vendors/memgraph.py index 15b218e4..9d2d7e9a 100644 --- a/gqlalchemy/vendors/memgraph.py +++ b/gqlalchemy/vendors/memgraph.py @@ -181,19 +181,17 @@ def get_enums(self) -> List[MemgraphEnum]: enums.append(MemgraphEnum(Enum(result['Enum Name'], result['Enum Values']))) return enums - def ensure_enums(self, graph_enums: List[MemgraphEnum]) -> None: - """Ensures that database enums match input enums.""" - current_enums = set(self.get_enums()) - new_enums = set(graph_enums) - for obsolete_enum in current_enums.difference(new_enums): - warnings.warn(GQLAlchemyWarning(f"DROP ENUM not yet implemented. Enum {obsolete_enum.__name__} is persisted in the database.")) - for missing_enum in new_enums.difference(current_enums): - self.create_enum(missing_enum) - - def drop_enum(self, graph_enum: Enum): - raise GQLAlchemyError(f"DROP ENUM not yet implemented. Enum {graph_enum.__name__} is persisted in the database.") + def sync_enum(self, existing: MemgraphEnum, new: MemgraphEnum) -> None: + """Ensures that database enum matches input enum.""" + for value in new.members: + if value not in existing.members: + query = f"ALTER ENUM {existing.name} ADD VALUE {value};" + self.execute(query) + + def drop_enum(self, graph_enum: MemgraphEnum): + raise GQLAlchemyError(f"DROP ENUM not yet implemented. Enum {graph_enum.name} is persisted in the database.") - def drop_enums(self, graph_enums: List[Enum]): + def drop_enums(self, graph_enums: List[MemgraphEnum]): raise GQLAlchemyError(f"DROP ENUM not yet implemented. Enums {', '.join(graph_enums)} are persisted in the database.") def get_exists_constraints( diff --git a/gqlalchemy/vendors/neo4j.py b/gqlalchemy/vendors/neo4j.py index a5b5aa58..569963b5 100644 --- a/gqlalchemy/vendors/neo4j.py +++ b/gqlalchemy/vendors/neo4j.py @@ -14,6 +14,7 @@ import os from typing import List, Optional, Union +from enum import Enum from gqlalchemy.connection import Connection, Neo4jConnection from gqlalchemy.exceptions import ( @@ -24,6 +25,7 @@ Neo4jConstraintExists, Neo4jConstraintUnique, Neo4jIndex, + GraphEnum, Node, Relationship, ) @@ -99,6 +101,23 @@ def ensure_indexes(self, indexes: List[Neo4jIndex]) -> None: for missing_index in new_indexes.difference(old_indexes): self.create_index(missing_index) + def create_enum(self, graph_enum: GraphEnum) -> None: + raise GQLAlchemyError(f"CREATE ENUM not yet implemented in Neo4j.") + + def get_enums(self) -> List[GraphEnum]: + """Returns a list of all enums defined in the database.""" + raise GQLAlchemyError(f"SHOW ENUMS not yet implemented in Neo4j.") + + def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None: + """Ensures that database enum matches input enum.""" + raise GQLAlchemyError(f"ALTER ENUM not yet implemented in Neo4j.") + + def drop_enum(self, graph_enum: GraphEnum): + raise GQLAlchemyError(f"DROP ENUM not yet implemented in Neo4j.") + + def drop_enums(self, graph_enums: List[GraphEnum]): + raise GQLAlchemyError(f"DROP ENUM not yet implemented in Neo4j.") + def get_constraints( self, ) -> List[Union[Neo4jConstraintExists, Neo4jConstraintUnique]]: From 9ee5de4caec83e7300c53c5e23315508723c7665 Mon Sep 17 00:00:00 2001 From: James Geiger Date: Thu, 27 Mar 2025 13:50:23 -0500 Subject: [PATCH 3/5] Adds creating enum on Node/Relationship definition --- gqlalchemy/models.py | 45 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/gqlalchemy/models.py b/gqlalchemy/models.py index 6dd0c89a..80c89389 100644 --- a/gqlalchemy/models.py +++ b/gqlalchemy/models.py @@ -72,6 +72,10 @@ def __init__(self, enum): def name(self): return self.cls.__name__ + @property + def members(self): + return self.cls.__members__ + @abstractmethod def _to_cypher(self): pass @@ -334,7 +338,7 @@ class Config: def __init__(self, **data): for field in self.__class__.__fields__: - attrs = self.__class__.__fields__[field] + attrs = self.__class__.__fields__[field].field_info.extra cls = self.__fields__[field].type_ if issubclass(cls, Enum) and not attrs.get("enum", False): value = data.get(field) @@ -553,6 +557,9 @@ def get_base_labels() -> Set[str]: cls.labels = get_base_labels().union({cls.label}, kwargs.get("labels", set())) db = kwargs.get("db") + + cls.enums = None + if cls.index is True: if db is None: raise GQLAlchemyDatabaseMissingInNodeClassError(cls=cls) @@ -570,9 +577,17 @@ def get_base_labels() -> Set[str]: if db is None: db = attrs.get("db") - # TODO: Implement enum creation and value add if issubclass(field_cls, Enum) and attrs.get("enum", False): - pass + if db is None: + raise GQLAlchemyDatabaseMissingInNodeClassError(cls=cls) + if cls.enums is None: + cls.enums = db.get_enums() + enum_names = [x.name for x in cls.enums] + if(field_cls.__name__ in enum_names): + existing = cls.enums[enum_names.index(field_cls.__name__)] + db.sync_enum(existing, MemgraphEnum(field_cls)) + else: + db.create_enum(MemgraphEnum(field_cls)) for constraint in FieldAttrsConstants.list(): if constraint in attrs and db is None: @@ -709,6 +724,30 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901 if name != "Relationship": cls.type = kwargs.get("type", name) + db = kwargs.get("db") + + cls.enums = None + + for field in cls.__fields__: + attrs = cls.__fields__[field].field_info.extra + field_type = cls.__fields__[field].type_.__name__ + field_cls = cls.__fields__[field].type_ + + if db is None: + db = attrs.get("db") + + if issubclass(field_cls, Enum) and attrs.get("enum", False): + if db is None: + raise GQLAlchemyDatabaseMissingInNodeClassError(cls=cls) + if cls.enums is None: + cls.enums = db.get_enums() + enum_names = [x.name for x in cls.enums] + if(field_type in enum_names): + existing = cls.enums[enum_names.index(field_type)] + db.sync_enum(existing, MemgraphEnum(field_cls)) + else: + db.create_enum(MemgraphEnum(field_cls)) + return cls From 7cc252a471514c456c0733d24cf90ef870a5f797 Mon Sep 17 00:00:00 2001 From: James Geiger Date: Thu, 27 Mar 2025 13:50:35 -0500 Subject: [PATCH 4/5] Updates documentation on usage of enums. --- docs/how-to-guides/ogm.md | 44 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/docs/how-to-guides/ogm.md b/docs/how-to-guides/ogm.md index a75da273..8a7081c4 100644 --- a/docs/how-to-guides/ogm.md +++ b/docs/how-to-guides/ogm.md @@ -395,6 +395,37 @@ To check which constraints have been created, run: print(db.get_constraints()) ``` +## Using enums + +Memgraph's built-in [enum data type](https://memgraph.com/docs/fundamentals/data-types#enum) can be utilized on your GQLAlchemy OGM models. GQLAlchemy's enum implementation extends Python's [enum support](https://docs.python.org/3.11/library/enum.html). + +First, create an enum. + +```python +from enum import Enum + +class SubscriptionType(Enum): + FREE = 1 + BASIC = 2 + EXTENDED = 3 +``` + +Then, use the defined enum class in your model definition. Using the `Field` class, set the `enum` attribute to `True`. This will indicate that GQLAlchemy should treat the property value stored as a Memgraph enum. If the enum does not exist in the database, it will be created. + +```python +class User(Node): + id: str = Field(index=True, db=db) + username: str + subscription: SubscriptionType = Field(enum=True, db=db) +``` + +Enum types may be defined for properties on Nodes and Relationships. + +!!! info + If the `Field` class specification on the property isn't specified, or if `enum` is explicitly set to `False`, GQLAlchemy will use the `value` of the enum member when serializing to a Cypher query. A corresponding enum will not be created in the database. + + This functionality allows for flexiblity when using the Python `Enum` class, and would, for instance, respect an overridden `__getattribute__` method to customize the value passed to Cypher. + ## Full code example The above mentioned examples can be merged into a working code example which you can run. Here is the code: @@ -402,12 +433,19 @@ The above mentioned examples can be merged into a working code example which you ```python from gqlalchemy import Memgraph, Node, Relationship, Field from typing import Optional +from enum import Enum db = Memgraph() +class SubscriptionType(Enum): + FREE = 1 + BASIC = 2 + EXTENDED = 3 + class User(Node): id: str = Field(index=True, db=db) username: str = Field(exists=True, db=db) + subscription: SubscriptionType = Field(enum=True, db=db) class Streamer(User): id: str @@ -423,8 +461,8 @@ class ChatsWith(Relationship, type="CHATS_WITH"): class Speaks(Relationship, type="SPEAKS"): since: Optional[str] -john = User(id="1", username="John").save(db) -jane = Streamer(id="2", username="janedoe", followers=111).save(db) +john = User(id="1", username="John", subscription=SubscriptionType(1)).save(db) +jane = Streamer(id="2", username="janedoe", subscription=SubscriptionType(3), followers=111).save(db) language = Language(name="en").save(db) ChatsWith( @@ -449,7 +487,7 @@ try: streamer = Streamer(id="3").load(db=db) except: print("Creating new Streamer node in the database.") - streamer = Streamer(id="3", username="anne", followers=222).save(db=db) + streamer = Streamer(id="3", username="anne", subscription=SubscriptionType(2), followers=222).save(db=db) try: speaks = Speaks(_start_node_id=streamer._id, _end_node_id=language._id).load(db) From bf7fee23deb0fe3b49151bad1cd2f72c88d29f05 Mon Sep 17 00:00:00 2001 From: James Geiger Date: Fri, 22 Aug 2025 12:44:59 -0500 Subject: [PATCH 5/5] Adding tests and updating clients, vendors --- gqlalchemy/models.py | 24 +++++++++++++----------- gqlalchemy/vendors/database_client.py | 15 +++++---------- gqlalchemy/vendors/memgraph.py | 14 +++++++------- gqlalchemy/vendors/neo4j.py | 7 +++---- tests/ogm/test_custom_fields.py | 16 ++++++++++++++++ 5 files changed, 44 insertions(+), 32 deletions(-) diff --git a/gqlalchemy/models.py b/gqlalchemy/models.py index 80c89389..aa0c80d2 100644 --- a/gqlalchemy/models.py +++ b/gqlalchemy/models.py @@ -16,7 +16,6 @@ from collections import defaultdict from dataclasses import dataclass from datetime import datetime, date, time, timedelta -from enum import Enum import json from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from enum import Enum, EnumMeta @@ -59,34 +58,37 @@ def _format_timedelta(duration: timedelta) -> str: return f"P{days}DT{hours}H{minutes}M{remainder_sec}S" + class GraphEnum(ABC): def __init__(self, enum): if not isinstance(enum, (Enum, EnumMeta)): raise TypeError() - + self.enum = enum if isinstance(enum, Enum) else None self.cls = enum.__class__ if isinstance(enum, Enum) else enum - + @property def name(self): return self.cls.__name__ - + @property def members(self): return self.cls.__members__ - + @abstractmethod def _to_cypher(self): pass + class MemgraphEnum(GraphEnum): def _to_cypher(self): return f"{{ {', '.join(self.cls._member_names_)} }}" - + def __repr__(self): - return f"" if self.enum is None else f'{self.name}::{self.enum.name}' - + return f"" if self.enum is None else f"{self.name}::{self.enum.name}" + + class TriggerEventType: """An enum representing types of trigger events.""" @@ -343,7 +345,7 @@ def __init__(self, **data): if issubclass(cls, Enum) and not attrs.get("enum", False): value = data.get(field) if isinstance(value, dict): - member = value.get("__value").split('::')[1] + member = value.get("__value").split("::")[1] data[field] = cls[member].value super().__init__(**data) @@ -583,7 +585,7 @@ def get_base_labels() -> Set[str]: if cls.enums is None: cls.enums = db.get_enums() enum_names = [x.name for x in cls.enums] - if(field_cls.__name__ in enum_names): + if field_cls.__name__ in enum_names: existing = cls.enums[enum_names.index(field_cls.__name__)] db.sync_enum(existing, MemgraphEnum(field_cls)) else: @@ -742,7 +744,7 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901 if cls.enums is None: cls.enums = db.get_enums() enum_names = [x.name for x in cls.enums] - if(field_type in enum_names): + if field_type in enum_names: existing = cls.enums[enum_names.index(field_type)] db.sync_enum(existing, MemgraphEnum(field_cls)) else: diff --git a/gqlalchemy/vendors/database_client.py b/gqlalchemy/vendors/database_client.py index 7e8f9f76..182c9cda 100644 --- a/gqlalchemy/vendors/database_client.py +++ b/gqlalchemy/vendors/database_client.py @@ -17,13 +17,8 @@ from gqlalchemy.connection import Connection from gqlalchemy.exceptions import GQLAlchemyError -from gqlalchemy.models import ( - Constraint, - Index, - GraphEnum, - Node, - Relationship -) +from gqlalchemy.models import Constraint, Index, GraphEnum, Node, Relationship + class DatabaseClient(ABC): def __init__( @@ -127,7 +122,7 @@ def ensure_constraints( self.drop_constraint(obsolete_constraints) for missing_constraint in new_constraints.difference(old_constraints): self.create_constraint(missing_constraint) - + @abstractmethod def create_enum(self, enum: GraphEnum) -> None: pass @@ -136,7 +131,7 @@ def create_enum(self, enum: GraphEnum) -> None: def get_enums(self) -> List[GraphEnum]: """Returns a list of all enums defined in the database.""" pass - + @abstractmethod def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None: """Ensures that database enum matches input enum.""" @@ -146,7 +141,7 @@ def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None: def drop_enum(self, enum: GraphEnum) -> None: """Drops a single enum in the database.""" pass - + @abstractmethod def drop_enums(self) -> None: """Drops all enums in the database""" diff --git a/gqlalchemy/vendors/memgraph.py b/gqlalchemy/vendors/memgraph.py index 9d2d7e9a..ba2a3f82 100644 --- a/gqlalchemy/vendors/memgraph.py +++ b/gqlalchemy/vendors/memgraph.py @@ -16,7 +16,6 @@ import os import sqlite3 from typing import List, Optional, Union -import warnings from gqlalchemy.connection import Connection, MemgraphConnection from gqlalchemy.disk_storage import OnDiskPropertyDatabase @@ -25,7 +24,6 @@ GQLAlchemyFileNotFoundError, GQLAlchemyOnDiskPropertyDatabaseNotDefinedError, GQLAlchemyUniquenessConstraintError, - GQLAlchemyWarning, ) from gqlalchemy.models import ( MemgraphConstraintExists, @@ -169,7 +167,7 @@ def get_constraints( ) ) return constraints - + def create_enum(self, graph_enum: MemgraphEnum) -> None: query = f"CREATE ENUM {graph_enum.name} VALUES {graph_enum._to_cypher()};" self.execute(query) @@ -178,9 +176,9 @@ def get_enums(self) -> List[MemgraphEnum]: """Returns a list of all enums defined in the database.""" enums: List[MemgraphEnum] = [] for result in self.execute_and_fetch("SHOW ENUMS;"): - enums.append(MemgraphEnum(Enum(result['Enum Name'], result['Enum Values']))) + enums.append(MemgraphEnum(Enum(result["Enum Name"], result["Enum Values"]))) return enums - + def sync_enum(self, existing: MemgraphEnum, new: MemgraphEnum) -> None: """Ensures that database enum matches input enum.""" for value in new.members: @@ -190,9 +188,11 @@ def sync_enum(self, existing: MemgraphEnum, new: MemgraphEnum) -> None: def drop_enum(self, graph_enum: MemgraphEnum): raise GQLAlchemyError(f"DROP ENUM not yet implemented. Enum {graph_enum.name} is persisted in the database.") - + def drop_enums(self, graph_enums: List[MemgraphEnum]): - raise GQLAlchemyError(f"DROP ENUM not yet implemented. Enums {', '.join(graph_enums)} are persisted in the database.") + raise GQLAlchemyError( + f"DROP ENUM not yet implemented. Enums {', '.join(graph_enums)} are persisted in the database." + ) def get_exists_constraints( self, diff --git a/gqlalchemy/vendors/neo4j.py b/gqlalchemy/vendors/neo4j.py index 569963b5..5fc0b2c2 100644 --- a/gqlalchemy/vendors/neo4j.py +++ b/gqlalchemy/vendors/neo4j.py @@ -14,7 +14,6 @@ import os from typing import List, Optional, Union -from enum import Enum from gqlalchemy.connection import Connection, Neo4jConnection from gqlalchemy.exceptions import ( @@ -107,14 +106,14 @@ def create_enum(self, graph_enum: GraphEnum) -> None: def get_enums(self) -> List[GraphEnum]: """Returns a list of all enums defined in the database.""" raise GQLAlchemyError(f"SHOW ENUMS not yet implemented in Neo4j.") - + def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None: """Ensures that database enum matches input enum.""" raise GQLAlchemyError(f"ALTER ENUM not yet implemented in Neo4j.") - + def drop_enum(self, graph_enum: GraphEnum): raise GQLAlchemyError(f"DROP ENUM not yet implemented in Neo4j.") - + def drop_enums(self, graph_enums: List[GraphEnum]): raise GQLAlchemyError(f"DROP ENUM not yet implemented in Neo4j.") diff --git a/tests/ogm/test_custom_fields.py b/tests/ogm/test_custom_fields.py index d346300d..d4987e1a 100644 --- a/tests/ogm/test_custom_fields.py +++ b/tests/ogm/test_custom_fields.py @@ -13,10 +13,13 @@ from pydantic.v1 import Field +from enum import Enum + from gqlalchemy import ( MemgraphConstraintExists, MemgraphConstraintUnique, MemgraphIndex, + MemgraphEnum, Neo4jConstraintUnique, Neo4jIndex, Node, @@ -56,6 +59,19 @@ def test_create_index(memgraph): assert actual_constraints == [memgraph_index] +def test_create_graph_enum(memgraph): + enum1 = Enum("MgEnum", (("MEMBER1", "value1"), ("MEMBER2", "value2"), ("MEMBER3", "value3"))) + + class Node3(Node): + type: enum1 + + memgraph_enum = MemgraphEnum(enum1) + + actual_enums = memgraph.get_enums() + + assert actual_enums == [memgraph_enum] + + def test_create_constraint_unique_neo4j(neo4j): class Node2(Node): id: int = Field(db=neo4j)