diff --git a/keep/api/core/db.py b/keep/api/core/db.py index e2efea2f2b..e04e3d613c 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -53,6 +53,7 @@ custom_serialize, get_json_extract_field, get_or_create, + insert_update_conflict, ) from keep.api.core.dependencies import SINGLE_TENANT_UUID @@ -166,7 +167,7 @@ def __convert_to_uuid(value: str, should_raise: bool = False) -> UUID | None: def retry_on_db_error(f): @retry( - exceptions=(OperationalError, IntegrityError, StaleDataError), + exceptions=(OperationalError, IntegrityError, StaleDataError, NoActiveSqlTransaction), tries=3, delay=0.1, backoff=2, @@ -177,7 +178,7 @@ def retry_on_db_error(f): def wrapper(*args, **kwargs): try: return f(*args, **kwargs) - except (OperationalError, IntegrityError, StaleDataError) as e: + except (OperationalError, IntegrityError, StaleDataError, NoActiveSqlTransaction) as e: if hasattr(e, "session") and not e.session.is_active: e.session.rollback() @@ -187,6 +188,11 @@ def wrapper(*args, **kwargs): "Deadlock detected, retrying transaction", extra={"error": str(e)} ) raise # retry will catch this + elif "No active SQL transaction" in str(e): + logger.exception( + "No active SQL transaction detected, retrying transaction", extra={"error": str(e)} + ) + raise # retry will catch this else: logger.exception( f"Error while executing transaction during {f.__name__}", @@ -5683,99 +5689,33 @@ def get_last_alert_by_fingerprint( query = query.with_for_update() return session.exec(query).first() - +@retry_on_db_error def set_last_alert( - tenant_id: str, alert: Alert, session: Optional[Session] = None, max_retries=3 + tenant_id: str, alert: Alert, session: Optional[Session] = None ) -> None: fingerprint = alert.fingerprint logger.info(f"Setting last alert for `{fingerprint}`") with existed_or_new_session(session) as session: - for attempt in range(max_retries): - logger.info( - f"Attempt {attempt} to set last alert for `{fingerprint}`", - extra={ - "alert_id": alert.id, - 
"tenant_id": tenant_id, - "fingerprint": fingerprint, - }, - ) - try: - last_alert = get_last_alert_by_fingerprint( - tenant_id, fingerprint, session, for_update=True - ) - - # To prevent rare, but possible race condition - # For example if older alert failed to process - # and retried after new one - if last_alert and last_alert.timestamp.replace( - tzinfo=tz.UTC - ) < alert.timestamp.replace(tzinfo=tz.UTC): - - logger.info( - f"Update last alert for `{fingerprint}`: {last_alert.alert_id} -> {alert.id}", - extra={ - "alert_id": alert.id, - "tenant_id": tenant_id, - "fingerprint": fingerprint, - }, - ) - last_alert.timestamp = alert.timestamp - last_alert.alert_id = alert.id - last_alert.alert_hash = alert.alert_hash - session.add(last_alert) - - elif not last_alert: - logger.info(f"No last alert for `{fingerprint}`, creating new") - last_alert = LastAlert( - tenant_id=tenant_id, - fingerprint=alert.fingerprint, - timestamp=alert.timestamp, - first_timestamp=alert.timestamp, - alert_id=alert.id, - alert_hash=alert.alert_hash, - ) - - session.add(last_alert) - session.commit() - break - except OperationalError as ex: - if "no such savepoint" in ex.args[0]: - logger.info( - f"No such savepoint while updating lastalert for `{fingerprint}`, retry #{attempt}" - ) - session.rollback() - if attempt >= max_retries: - raise ex - continue - - if "Deadlock found" in ex.args[0]: - logger.info( - f"Deadlock found while updating lastalert for `{fingerprint}`, retry #{attempt}" - ) - session.rollback() - if attempt >= max_retries: - raise ex - continue - except NoActiveSqlTransaction: - logger.exception( - f"No active sql transaction while updating lastalert for `{fingerprint}`, retry #{attempt}", - extra={ - "alert_id": alert.id, - "tenant_id": tenant_id, - "fingerprint": fingerprint, - }, - ) - continue - logger.debug( - f"Successfully updated lastalert for `{fingerprint}`", - extra={ - "alert_id": alert.id, - "tenant_id": tenant_id, - "fingerprint": fingerprint, - }, - ) - # 
def insert_update_conflict(
    table: "type[SQLModel]",
    session: "Session",
    data_to_insert: dict,
    data_to_update: dict,
    update_newer: bool,
):
    """
    Perform an upsert (insert, or update on conflict) on the given table.

    The statement is built with the dialect-specific ``insert`` construct
    selected from ``session.bind.dialect.name``, executed, and committed on
    the given session.

    Args:
        table (SQLModel): The mapped model class to upsert into. On
            PostgreSQL and SQLite its primary-key columns are used as the
            conflict target.
        session (Session): The SQLModel session. It is committed by this
            function.
        data_to_insert (dict): Column values for the row to insert.
        data_to_update (dict): Column values to apply when the row already
            exists.
        update_newer (bool): If True, the existing row is updated only when
            its ``timestamp`` column is older than the incoming ``timestamp``.
            Requires ``table`` to expose a ``timestamp`` column.

    Raises:
        NotImplementedError: If the session's dialect is not one of
            postgresql, mysql or sqlite.
    """
    dialect = session.bind.dialect.name

    if dialect in ("postgresql", "sqlite"):
        # Both dialects expose the same ON CONFLICT ... DO UPDATE API.
        insert_fn = pg_insert if dialect == "postgresql" else sqlite_insert
        query = insert_fn(table).values(data_to_insert)
        query = query.on_conflict_do_update(
            index_elements=[
                col.name for col in table.__table__.primary_key.columns
            ],
            set_=data_to_update,
            where=(
                (table.timestamp < query.excluded.timestamp)
                if update_newer
                else None
            ),
        )
    elif dialect == "mysql":
        query = mysql_insert(table).values(data_to_insert)
        if update_newer:
            # MySQL has no WHERE on ON DUPLICATE KEY UPDATE, so every column
            # is guarded with CASE. MySQL evaluates the assignments left to
            # right, so `timestamp` MUST be assigned last: once it is
            # overwritten, the guard `timestamp < VALUES(timestamp)` turns
            # false and any column assigned afterwards keeps its stale value.
            # A list of 2-tuples (instead of a dict) makes SQLAlchemy emit
            # the assignments in exactly this order.
            is_newer = table.timestamp < query.inserted.timestamp
            ordered_items = [
                (key, value)
                for key, value in data_to_update.items()
                if key != "timestamp"
            ]
            if "timestamp" in data_to_update:
                ordered_items.append(("timestamp", data_to_update["timestamp"]))
            query = query.on_duplicate_key_update(
                [
                    (key, case((is_newer, value), else_=getattr(table, key)))
                    for key, value in ordered_items
                ]
            )
        else:
            query = query.on_duplicate_key_update(data_to_update)
    else:
        raise NotImplementedError(f"UPSERT not supported for {dialect}")

    session.exec(query)
    session.commit()