diff --git a/ibis-server/README.md b/ibis-server/README.md index b7b16922c..71e7c1a9d 100644 --- a/ibis-server/README.md +++ b/ibis-server/README.md @@ -114,6 +114,75 @@ docker compose down -v ``` +### Running Doris Tests Locally +Doris-related tests require a running Apache Doris instance. +Our GitHub CI already handles this automatically, but you must start Doris manually when running tests locally. + +Prerequisites + +- Docker & Docker Compose +- Python dependencies installed (`just install`) +- `pymysql` installed in the dev environment (already included in dev dependencies) + +#### Config Doris Cluster + +1. Start the Doris Container + +From the `ibis-server` directory: +```bash +cd tests/routers/v3/connector/doris +docker compose up -d +``` + +The container uses `apache/doris:4.0.3-all-slim` (all-in-one image with FE + BE). + +> ⚠️ The all-in-one Doris image requires sufficient memory (at least 8 GB recommended). +> If you see `MEM_ALLOC_FAILED` errors, increase Docker's memory limit. + +Wait until Doris is healthy. Check the status: +```bash +mysql -h 127.0.0.1 -P 9030 -uroot -e "SHOW BACKENDS\G" | grep "Alive" +# Alive: true +``` + +2. Update Connection Info (if needed) + +The default connection in `tests/routers/v3/connector/doris/conftest.py`: +```python +DORIS_HOST = "127.0.0.1" +DORIS_PORT = 9030 +DORIS_USER = "root" +DORIS_PASSWORD = "" +``` + + +Adjust these values if your Doris instance has different credentials. + +If you already have a remote Doris cluster, update the connection constants in `conftest.py`: +```python +DORIS_HOST = "" +DORIS_PORT = 9030 +DORIS_USER = "" +DORIS_PASSWORD = "" +``` + +#### Run Doris Tests + +Go back to the `ibis-server` directory and run: +```bash +just test doris +``` + +⚠️ Doris tests will fail if the Doris instance is not reachable. 
+ +#### Cleanup (Local Docker) + +After tests finish: +```bash +cd tests/routers/v3/connector/doris +docker compose down -v +``` + ### Start with Python Interactive Mode Install the dependencies diff --git a/ibis-server/app/custom_sqlglot/dialects/__init__.py b/ibis-server/app/custom_sqlglot/dialects/__init__.py index b40c47a3a..ae8f3ab83 100644 --- a/ibis-server/app/custom_sqlglot/dialects/__init__.py +++ b/ibis-server/app/custom_sqlglot/dialects/__init__.py @@ -1,3 +1,4 @@ # ruff: noqa: F401 +from app.custom_sqlglot.dialects.doris import Doris from app.custom_sqlglot.dialects.mysql import MySQL diff --git a/ibis-server/app/custom_sqlglot/dialects/doris.py b/ibis-server/app/custom_sqlglot/dialects/doris.py new file mode 100644 index 000000000..3f3baea27 --- /dev/null +++ b/ibis-server/app/custom_sqlglot/dialects/doris.py @@ -0,0 +1,10 @@ +from sqlglot import exp +from sqlglot.dialects import Doris as OriginalDoris + + +class Doris(OriginalDoris): + class Generator(OriginalDoris.Generator): + TYPE_MAPPING = { + **OriginalDoris.Generator.TYPE_MAPPING, + exp.DataType.Type.VARBINARY: "BINARY", + } diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index 86ecd0b48..6e10a83f0 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -53,6 +53,10 @@ class QueryMySqlDTO(QueryDTO): connection_info: ConnectionUrl | MySqlConnectionInfo = connection_info_field +class QueryDorisDTO(QueryDTO): + connection_info: DorisConnectionInfo = connection_info_field + + class QueryOracleDTO(QueryDTO): connection_info: ConnectionUrl | OracleConnectionInfo = connection_info_field @@ -322,6 +326,29 @@ class MySqlConnectionInfo(BaseConnectionInfo): ) +class DorisConnectionInfo(BaseConnectionInfo): + host: SecretStr = Field( + description="the hostname of your Doris FE", examples=["localhost"] + ) + port: SecretStr = Field( + description="the query port of your Doris FE", examples=["9030"] + ) + database: SecretStr = Field( + 
description="the database name of your Doris database", examples=["default"] + ) + user: SecretStr = Field( + description="the username of your Doris database", examples=["root"] + ) + password: SecretStr | None = Field( + description="the password of your Doris database", + examples=["password"], + default=None, + ) + kwargs: dict[str, str] | None = Field( + description="Additional keyword arguments to pass to PyMySQL", default=None + ) + + class PostgresConnectionInfo(BaseConnectionInfo): host: SecretStr = Field( examples=["localhost"], description="the hostname of your database" @@ -636,6 +663,7 @@ class GcsFileConnectionInfo(BaseConnectionInfo): | ConnectionUrl | MSSqlConnectionInfo | MySqlConnectionInfo + | DorisConnectionInfo | OracleConnectionInfo | PostgresConnectionInfo | RedshiftConnectionInfo diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index c86ff6877..37af156ea 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -104,6 +104,8 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo): self._connector = DatabricksConnector(connection_info) elif data_source == DataSource.mysql: self._connector = MySqlConnector(connection_info) + elif data_source == DataSource.doris: + self._connector = DorisConnector(connection_info) else: self._connector = IbisConnector(data_source, connection_info) @@ -357,6 +359,43 @@ def _cast_json_columns(self, result_table: Table, col_name: str) -> Table: return result_table.mutate(**{col_name: casted_col}) +class DorisConnector(IbisConnector): + """Doris connector - reuses MySQL protocol via ibis.mysql backend. + + Doris is an analytical database that is MySQL-protocol compatible. + Autocommit is forced on in get_doris_connection() because Doris may not + properly reflect the SERVER_STATUS_AUTOCOMMIT flag, which would cause + ibis's raw_sql() to wrap every query in BEGIN/ROLLBACK unnecessarily. 
+ """ + + def __init__(self, connection_info: ConnectionInfo): + super().__init__(DataSource.doris, connection_info) + + def _handle_pyarrow_unsupported_type(self, ibis_table: Table, **kwargs) -> Table: + result_table = ibis_table + for name, dtype in ibis_table.schema().items(): + if isinstance(dtype, Decimal): + result_table = self._round_decimal_columns( + result_table=result_table, col_name=name, **kwargs + ) + elif isinstance(dtype, UUID): + result_table = self._cast_uuid_columns( + result_table=result_table, col_name=name + ) + elif isinstance(dtype, dt.JSON): + # Doris JSON columns need the same handling as MySQL + result_table = self._cast_json_columns( + result_table=result_table, col_name=name + ) + + return result_table + + def _cast_json_columns(self, result_table: Table, col_name: str) -> Table: + col = result_table[col_name] + casted_col = col.cast("string") + return result_table.mutate(**{col_name: casted_col}) + + class MSSqlConnector(IbisConnector): def __init__(self, connection_info: ConnectionInfo): super().__init__(DataSource.mssql, connection_info) diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index c30aa8827..9ca9830de 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -24,6 +24,7 @@ ConnectionUrl, DatabricksServicePrincipalConnectionInfo, DatabricksTokenConnectionInfo, + DorisConnectionInfo, GcsFileConnectionInfo, LocalFileConnectionInfo, MinioFileConnectionInfo, @@ -36,6 +37,7 @@ QueryCannerDTO, QueryClickHouseDTO, QueryDatabricksDTO, + QueryDorisDTO, QueryDTO, QueryGcsFileDTO, QueryLocalFileDTO, @@ -69,6 +71,7 @@ class DataSource(StrEnum): clickhouse = auto() mssql = auto() mysql = auto() + doris = auto() oracle = auto() postgres = auto() redshift = auto() @@ -167,6 +170,8 @@ def _build_connection_info(self, data: dict) -> ConnectionInfo: return MSSqlConnectionInfo.model_validate(data) case DataSource.mysql: return 
MySqlConnectionInfo.model_validate(data) + case DataSource.doris: + return DorisConnectionInfo.model_validate(data) case DataSource.oracle: return OracleConnectionInfo.model_validate(data) case DataSource.postgres: @@ -236,6 +241,7 @@ class DataSourceExtension(Enum): clickhouse = QueryClickHouseDTO mssql = QueryMSSqlDTO mysql = QueryMySqlDTO + doris = QueryDorisDTO oracle = QueryOracleDTO postgres = QueryPostgresDTO redshift = QueryRedshiftDTO @@ -397,6 +403,42 @@ def get_mysql_connection(cls, info: MySqlConnectionInfo) -> BaseBackend: **kwargs, ) + @classmethod + def get_doris_connection(cls, info: DorisConnectionInfo) -> BaseBackend: + kwargs = {} + + # utf8mb4 is the actual charset used by Doris (MySQL-compatible) + kwargs.setdefault("charset", "utf8mb4") + + if info.kwargs: + kwargs.update(info.kwargs) + # Doris is MySQL-protocol compatible, reuse ibis.mysql.connect() + connection = ibis.mysql.connect( + host=info.host.get_secret_value(), + port=int(info.port.get_secret_value()), + database=info.database.get_secret_value(), + user=info.user.get_secret_value(), + password=info.password.get_secret_value() if info.password else "", + **kwargs, + ) + # Doris does not properly reflect the SERVER_STATUS_AUTOCOMMIT flag + # in its MySQL-protocol handshake/OK packets. As a result, the + # underlying mysqlclient driver's get_autocommit() always returns + # False — even after explicitly calling autocommit(True). + # + # ibis's raw_sql() checks get_autocommit() and, when it returns + # False, wraps every query in BEGIN/ROLLBACK. Doris (an OLAP engine) + # does not support transactional SELECT inside BEGIN and will reject + # with: "This is in a transaction, only insert, update, delete, + # commit, rollback is acceptable." + # + # Fix: override get_autocommit on THIS connection instance only so + # that ibis skips the BEGIN/ROLLBACK wrapping. 
This is a per-object + # attribute override — it does NOT affect the MySQLdb class, other + # MySQL connections, or any other data-source driver. + connection.con.get_autocommit = lambda: True + return connection + @staticmethod def get_postgres_connection(info: PostgresConnectionInfo) -> BaseBackend: return ibis.postgres.connect( diff --git a/ibis-server/app/model/metadata/doris.py b/ibis-server/app/model/metadata/doris.py new file mode 100644 index 000000000..0d91b2852 --- /dev/null +++ b/ibis-server/app/model/metadata/doris.py @@ -0,0 +1,212 @@ +import re + +from loguru import logger + +from app.model import DorisConnectionInfo +from app.model.data_source import DataSource +from app.model.metadata.dto import ( + Catalog, + Column, + Constraint, + RustWrenEngineColumnType, + Table, + TableProperties, +) +from app.model.metadata.metadata import Metadata + +# Doris-specific type mapping +# Doris is MySQL-protocol compatible but has additional types +# Reference: https://doris.apache.org/docs/sql-manual/data-types/ +DORIS_TYPE_MAPPING = { + # ── String Types ───────────────────────────── + "char": RustWrenEngineColumnType.CHAR, + "varchar": RustWrenEngineColumnType.VARCHAR, + "string": RustWrenEngineColumnType.VARCHAR, + "text": RustWrenEngineColumnType.TEXT, + "tinytext": RustWrenEngineColumnType.TEXT, + "mediumtext": RustWrenEngineColumnType.TEXT, + "longtext": RustWrenEngineColumnType.TEXT, + # ── Numeric Types ──────────────────────────── + "tinyint": RustWrenEngineColumnType.TINYINT, + "smallint": RustWrenEngineColumnType.SMALLINT, + "int": RustWrenEngineColumnType.INTEGER, + "integer": RustWrenEngineColumnType.INTEGER, + "mediumint": RustWrenEngineColumnType.INTEGER, + "bigint": RustWrenEngineColumnType.BIGINT, + "largeint": RustWrenEngineColumnType.BIGINT, + # ── Boolean Types ──────────────────────────── + "boolean": RustWrenEngineColumnType.BOOL, + "bool": RustWrenEngineColumnType.BOOL, + # ── Decimal Types ──────────────────────────── + "float": 
RustWrenEngineColumnType.FLOAT8, + "double": RustWrenEngineColumnType.DOUBLE, + "decimal": RustWrenEngineColumnType.DECIMAL, + "decimalv3": RustWrenEngineColumnType.DECIMAL, + "numeric": RustWrenEngineColumnType.NUMERIC, + # ── Date and Time Types ────────────────────── + "date": RustWrenEngineColumnType.DATE, + "datev2": RustWrenEngineColumnType.DATE, + "datetime": RustWrenEngineColumnType.TIMESTAMP, + "datetimev2": RustWrenEngineColumnType.TIMESTAMP, + "timestamp": RustWrenEngineColumnType.TIMESTAMPTZ, + # ── JSON Type ──────────────────────────────── + "json": RustWrenEngineColumnType.JSON, + "jsonb": RustWrenEngineColumnType.JSON, + "variant": RustWrenEngineColumnType.JSON, + # ── Complex Types (map to JSON for compatibility) ─ + "array": RustWrenEngineColumnType.JSON, + "map": RustWrenEngineColumnType.JSON, + "struct": RustWrenEngineColumnType.JSON, + # ── Doris-specific aggregate types (map to VARCHAR) ─ + "hll": RustWrenEngineColumnType.VARCHAR, + "bitmap": RustWrenEngineColumnType.VARCHAR, + "quantile_state": RustWrenEngineColumnType.VARCHAR, + "agg_state": RustWrenEngineColumnType.VARCHAR, +} + + +class DorisMetadata(Metadata): + def __init__(self, connection_info: DorisConnectionInfo): + super().__init__(connection_info) + self.connection = DataSource.doris.get_connection(connection_info) + self.database = connection_info.database.get_secret_value() + + def get_table_list(self) -> list[Table]: + sql = """ + SELECT + c.TABLE_SCHEMA AS table_schema, + c.TABLE_NAME AS table_name, + c.COLUMN_NAME AS column_name, + c.COLUMN_TYPE AS data_type, + c.IS_NULLABLE AS is_nullable, + c.COLUMN_KEY AS column_key, + c.COLUMN_COMMENT AS column_comment, + t.TABLE_COMMENT AS table_comment + FROM + information_schema.COLUMNS c + JOIN + information_schema.TABLES t + ON c.TABLE_SCHEMA = t.TABLE_SCHEMA + AND c.TABLE_NAME = t.TABLE_NAME + WHERE + c.TABLE_SCHEMA NOT IN ('information_schema', '__internal_schema', 'mysql', 'performance_schema', 'sys') + ORDER BY + c.TABLE_SCHEMA, + 
c.TABLE_NAME, + c.ORDINAL_POSITION; + """ + response = self.connection.sql(sql).to_pandas().to_dict(orient="records") + + unique_tables = {} + for row in response: + # generate unique table name + schema_table = self._format_compact_table_name( + row["table_schema"], row["table_name"] + ) + # init table if not exists + if schema_table not in unique_tables: + unique_tables[schema_table] = Table( + name=schema_table, + description=row["table_comment"], + columns=[], + properties=TableProperties( + schema=row["table_schema"], + catalog="", + table=row["table_name"], + ), + primaryKey=None, + ) + + # table exists, and add column to the table + unique_tables[schema_table].columns.append( + Column( + name=row["column_name"], + type=self._transform_column_type(row["data_type"]), + notNull=row["is_nullable"].lower() == "no", + description=row["column_comment"], + properties=None, + ) + ) + # if column is primary key (Doris Unique Key model) + if row["column_key"] == "UNI" or row["column_key"] == "PRI": + existing_pk = unique_tables[schema_table].primaryKey + if existing_pk: + # Support composite keys + unique_tables[ + schema_table + ].primaryKey = f"{existing_pk},{row['column_name']}" + else: + unique_tables[schema_table].primaryKey = row["column_name"] + return list(unique_tables.values()) + + def get_constraints(self) -> list[Constraint]: + # Doris does not support foreign key constraints. + # Return an empty list as there are no referential constraints. 
+ return [] + + def get_schema_list(self, filter_info=None, limit=None) -> list[Catalog]: + sql = """ + SELECT SCHEMA_NAME + FROM information_schema.SCHEMATA + WHERE SCHEMA_NAME NOT IN ('information_schema', '__internal_schema', 'mysql', 'performance_schema', 'sys') + ORDER BY SCHEMA_NAME + """ + if limit is not None: + try: + validated_limit = int(limit) + except (TypeError, ValueError) as exc: + raise ValueError("limit must be an integer") from exc + if validated_limit < 0: + raise ValueError("limit must be non-negative") + sql += f" LIMIT {validated_limit}" + response = self.connection.sql(sql).to_pandas() + schemas = response["SCHEMA_NAME"].tolist() + # Doris has a flat namespace (no multi-catalog). + # Return a single Catalog entry whose name is the catalog reported by + # the current connection (usually "internal"), with all user databases + # listed as schemas. + try: + catalog_name = ( + self.connection.sql("SELECT CURRENT_CATALOG()").to_pandas().iloc[0, 0] + ) + except Exception: + catalog_name = "internal" + return [Catalog(name=catalog_name, schemas=schemas)] + + def get_version(self) -> str: + return self.connection.sql("SELECT version()").to_pandas().iloc[0, 0] + + def _format_compact_table_name(self, schema: str, table: str): + return f"{schema}.{table}" + + def _format_constraint_name( + self, table_name, column_name, referenced_table_name, referenced_column_name + ): + return f"{table_name}_{column_name}_{referenced_table_name}_{referenced_column_name}" + + def _transform_column_type(self, data_type: str) -> RustWrenEngineColumnType: + """Transform Doris data type to RustWrenEngineColumnType. + + Uses COLUMN_TYPE from information_schema which returns Doris-native + type names (e.g. 'largeint', 'string', 'decimalv3(10,2)') rather + than MySQL-compatible DATA_TYPE (which maps largeint → 'bigint unsigned'). 
+ + Args: + data_type: The Doris COLUMN_TYPE string + + Returns: + The corresponding RustWrenEngineColumnType + """ + # Strip precision/length info: int(11) -> int, decimalv3(10,2) -> decimalv3 + # Also strip angle-bracket generics: ARRAY -> array, MAP -> map + normalized_type = re.sub(r"(<.*>|\(.*\))", "", data_type.strip()).lower() + + # Use the module-level mapping table + mapped_type = DORIS_TYPE_MAPPING.get( + normalized_type, RustWrenEngineColumnType.UNKNOWN + ) + + if mapped_type == RustWrenEngineColumnType.UNKNOWN: + logger.warning(f"Unknown Doris data type: {data_type}") + + return mapped_type diff --git a/ibis-server/app/model/metadata/factory.py b/ibis-server/app/model/metadata/factory.py index 4c8d49c6d..c3fd2b66d 100644 --- a/ibis-server/app/model/metadata/factory.py +++ b/ibis-server/app/model/metadata/factory.py @@ -4,6 +4,7 @@ from app.model.metadata.canner import CannerMetadata from app.model.metadata.clickhouse import ClickHouseMetadata from app.model.metadata.databricks import DatabricksMetadata +from app.model.metadata.doris import DorisMetadata from app.model.metadata.metadata import Metadata from app.model.metadata.mssql import MSSQLMetadata from app.model.metadata.mysql import MySQLMetadata @@ -28,6 +29,7 @@ DataSource.clickhouse: ClickHouseMetadata, DataSource.mssql: MSSQLMetadata, DataSource.mysql: MySQLMetadata, + DataSource.doris: DorisMetadata, DataSource.oracle: OracleMetadata, DataSource.postgres: PostgresMetadata, DataSource.redshift: RedshiftMetadata, diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index 631eea730..ef7a5f35f 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -103,10 +103,10 @@ def _with_session_timezone( ) ) continue - if data_source == DataSource.mysql: + if data_source in {DataSource.mysql, DataSource.doris}: timezone = headers.get(X_WREN_TIMEZONE, "UTC") # TODO: ibis mysql loss the timezone information - # we cast timestamp to timestamp with session timezone for mysql + # we cast 
timestamp to timestamp with session timezone for mysql/doris fields.append( pa.field( field.name, diff --git a/ibis-server/poetry.lock b/ibis-server/poetry.lock index 679dbdea6..d5bf10959 100644 --- a/ibis-server/poetry.lock +++ b/ibis-server/poetry.lock @@ -8702,4 +8702,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.12" -content-hash = "d79c85212036201c4bd7fb3776e3c99b93f0cff5b22d9ac71abf2ef722bf9829" +content-hash = "d132947b37fe8ae18307790fca7ab704840d76f384b51c52fd989c41e0e1831d" diff --git a/ibis-server/pyproject.toml b/ibis-server/pyproject.toml index ca77a5cd7..c10df9323 100644 --- a/ibis-server/pyproject.toml +++ b/ibis-server/pyproject.toml @@ -83,6 +83,7 @@ pre-commit = "4.2.0" ruff = "0.11.2" trino = ">=0.321,<1" psycopg2 = ">=2.8.4,<3" +pymysql = ">=1.1.0,<2" clickhouse-connect = "0.8.15" asgi-lifespan = "2.1.0" polars = ">=1.32.0" @@ -97,6 +98,7 @@ markers = [ "functions: mark a test as a functions test", "mssql: mark a test as a mssql test", "mysql: mark a test as a mysql test", + "doris: mark a test as a doris test", "oracle: mark a test as a oracle test", "postgres: mark a test as a postgres test", "redshift: mark a test as a redshift test", diff --git a/ibis-server/resources/function_list/doris.csv b/ibis-server/resources/function_list/doris.csv new file mode 100644 index 000000000..c9822aedd --- /dev/null +++ b/ibis-server/resources/function_list/doris.csv @@ -0,0 +1,136 @@ +function_type,name,return_type,param_names,param_types,description +scalar,tan,double,,,"Returns the tangent of a number." +scalar,sqrt,double,,,"Returns the square root of a number." +scalar,trim,varchar,,,"Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string." +scalar,log10,double,,,"Returns the base-10 logarithm of a number." 
+scalar,nvl,varchar,,,"Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_." +scalar,char_length,bigint,,,"Returns the number of characters in a string." +scalar,radians,double,,,"Converts degrees to radians." +scalar,ln,double,,,"Returns the natural logarithm of a number." +aggregate,bool_or,boolean,,,"Returns true if any non-null input value is true, otherwise false." +window,nth_value,same_as_input,,,"Returns the value evaluated at the nth row of the window frame (counting from 1). Returns NULL if no such row exists." +scalar,regexp_replace,varchar,,,"Replaces substrings in a string that match a regular expression." +scalar,date_format,varchar,,,"Returns a string representation of a date, time, timestamp or duration based on a format." +scalar,substr,varchar,,,"Extracts a substring of a specified number of characters from a specific starting position in a string." +scalar,ascii,int,,,"Returns the Unicode character code of the first character in a string." +scalar,regexp_instr,bigint,,,"Returns the position in a string where the specified occurrence of a POSIX regular expression is located." +scalar,character_length,bigint,,,"Returns the number of characters in a string." +aggregate,var_samp,double,,,"Returns the statistical sample variance of a set of numbers." +aggregate,var_pop,double,,,"Returns the statistical population variance of a set of numbers." +aggregate,array_agg,same_as_input_first_array_element,,,"Returns an array created from the expression elements." +scalar,floor,double,,,"Returns the nearest integer less than or equal to a number." +scalar,strpos,bigint,,,"Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0." +scalar,md5,varchar,,,"Computes an MD5 128-bit checksum for a string expression." +scalar,signum,int,,,"Returns the sign of a number." 
+scalar,date_trunc,datetime,,,"Truncates a timestamp value to a specified precision." +scalar,lower,varchar,,,"Converts a string to lower-case." +scalar,length,bigint,,,"Returns the number of characters in a string." +scalar,chr,varchar,,,"Returns the character with the specified ASCII code value." +scalar,greatest,same_as_input,,,"Returns the greatest value in a list of expressions." +scalar,reverse,varchar,,,"Reverses the character order of a string." +scalar,cot,double,,,"Returns the cotangent of a number." +scalar,power,double,,,"Returns a base expression raised to the power of an exponent." +aggregate,min,same_as_input,,,"Returns the minimum value in the specified column." +aggregate,bit_xor,bigint,,,"Computes the bitwise exclusive OR of all non-null input values." +aggregate,avg,decimal,,,"Returns the average of numeric values in the specified column." +aggregate,stddev_samp,decimal,,,"Returns the sample standard deviation of a set of numbers." +aggregate,sum,decimal,,,"Returns the sum of all values in the specified column." +aggregate,bit_or,bigint,,,"Computes the bitwise OR of all non-null input values." +window,last_value,same_as_input,,,"Returns value evaluated at the row that is the last row of the window frame." +scalar,round,double,,,"Rounds a number to the nearest integer." +scalar,asin,double,,,"Returns the arc sine or inverse sine of a number." +scalar,upper,varchar,,,"Converts a string to upper-case." +scalar,position,bigint,,,"Returns the starting position of a specified substring in a string." +scalar,atan2,double,,,"Returns the arc tangent or inverse tangent of `expression_y / expression_x`." +scalar,acos,double,,,"Returns the arc cosine or inverse cosine of a number." +scalar,right,varchar,,,"Returns a specified number of characters from the right side of a string." +scalar,left,varchar,,,"Returns a specified number of characters from the left side of a string." 
+window,first_value,same_as_input,,,"Returns value evaluated at the row that is the first row of the window frame." +scalar,datetrunc,datetime,,,"Truncates a timestamp value to a specified precision." +scalar,current_timestamp,datetime,,,"Returns the current UTC timestamp." +scalar,find_in_set,bigint,,,"Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings." +scalar,to_hex,varchar,,,"Converts an integer to a hexadecimal string." +scalar,octet_length,bigint,,,"Returns the length of a string in bytes." +scalar,nullif,same_as_input,,,"Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_." +scalar,replace,varchar,,,"Replaces all occurrences of a specified substring in a string with a new substring." +scalar,today,date,,,"Returns the current UTC date." +scalar,substring,varchar,,,"Extracts a substring of a specified number of characters from a specific starting position in a string." +aggregate,last_value,same_as_input,,,"Returns the last element in an aggregation group according to the requested ordering." +aggregate,nth_value,same_as_input,,,"Returns the nth value in a group of values." +scalar,instr,bigint,,,"Returns the starting position of a specified substring in a string." +scalar,coalesce,same_as_input,,,"Returns the first of its arguments that is not _null_." +scalar,concat,varchar,,,"Concatenates multiple strings together." +scalar,from_unixtime,datetime,,,"Converts an integer to RFC3339 timestamp format." +scalar,log2,double,,,"Returns the base-2 logarithm of a number." +scalar,ltrim,varchar,,,"Trims the specified trim string from the beginning of a string." +scalar,bit_length,bigint,,,"Returns the bit length of a string." +scalar,abs,same_as_input,,,"Returns the absolute value of a number." +scalar,ceil,double,,,"Returns the nearest integer greater than or equal to a number." +scalar,cos,double,,,"Returns the cosine of a number." 
+scalar,random,double,,,"Returns a random float value in the range [0, 1)." +scalar,version,varchar,,,"Returns the version of MySQL." +scalar,rpad,varchar,,,"Pads the right side of a string with another string to a specified string length." +scalar,rtrim,varchar,,,"Trims the specified trim string from the end of a string." +aggregate,count,bigint,,,"Returns the number of non-null values in the specified column." +aggregate,bit_and,bigint,,,"Computes the bitwise AND of all non-null input values." +aggregate,stddev_pop,double,,,"Returns the population standard deviation of a set of numbers." +aggregate,first_value,same_as_input,,,"Returns the first element in an aggregation group." +aggregate,max,same_as_input,,,"Returns the maximum value in the specified column." +aggregate,stddev,double,,,"Returns the standard deviation of a set of numbers." +window,percent_rank,double,,,"Returns the percentage rank of the current row within its partition." +window,rank,bigint,,,"Returns the rank of the current row within its partition." +window,dense_rank,bigint,,,"Returns the rank of the current row without gaps." +window,lead,same_as_input,,,"Returns value evaluated at the row that is offset rows after the current row." +window,ntile,int,,,"Integer ranging from 1 to the argument value, dividing the partition as equally as possible." +scalar,atan,double,,,"Returns the arc tangent or inverse tangent of a number." +scalar,uuid,varchar,,,"Returns uuid v4 string value." +scalar,degrees,double,,,"Converts radians to degrees." +scalar,sin,double,,,"Returns the sine of a number." +scalar,now,datetime,,,"Returns the current UTC timestamp." +scalar,log,double,,,"Returns the base-x logarithm of a number." +scalar,least,same_as_input,,,"Returns the smallest value in a list of expressions." +scalar,current_time,time,,,"Returns the current UTC time." +scalar,concat_ws,varchar,,,"Concatenates multiple strings together with a specified separator." 
+scalar,pi,double,,,"Returns an approximate value of π." +scalar,substring_index,varchar,,,"Returns the substring from str before count occurrences of the delimiter." +scalar,nvl2,varchar,,,"Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_." +scalar,sha256,varchar,,,"Computes the SHA-256 hash of a binary string." +scalar,sha512,varchar,,,"Computes the SHA-512 hash of a binary string." +scalar,ifnull,same_as_input,,,"Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_." +scalar,regexp_like,boolean,,,"Returns true if a regular expression has at least one match in a string." +scalar,exp,double,,,"Returns the base-e exponential of a number." +scalar,lpad,varchar,,,"Pads the left side of a string with another string to a specified string length." +scalar,repeat,varchar,,,"Returns a string with an input string repeated a specified number." +scalar,to_char,varchar,,,"Returns a string representation of a date, time, timestamp or duration based on a format." +scalar,pow,double,,,"Returns a base expression raised to the power of an exponent." +scalar,current_date,date,,,"Returns the current UTC date." +aggregate,string_agg,varchar,,,"Concatenates the values of string expressions with separator values." +aggregate,bool_and,boolean,,,"Returns true if all non-null input values are true, otherwise false." +window,lag,same_as_input,,,"Returns value evaluated at the row that is offset rows before the current row." +window,cume_dist,double,,,"Relative rank of the current row." +window,row_number,bigint,,,"Number of the current row within its partition, counting from 1." 
+scalar,if,bool,,"boolean,any,any","Returns one value if a condition is TRUE, or another value if a condition is FALSE" +scalar,ceiling,int,,"decimal","Returns the smallest integer value greater than or equal to a number" +scalar,datediff,int,,"date,date","Returns the number of days between two dates" +scalar,timestampdiff,int,,"varchar,datetime,datetime","Returns the difference between two datetime expressions" +scalar,inet_aton,int,,"varchar","Converts an IPv4 address to numeric value" +scalar,inet_ntoa,varchar,,"int","Converts numeric value to IPv4 address" +scalar,format,varchar,,"decimal,int","Formats number to specified decimal places and adds thousand separators" +scalar,hex,varchar,,"decimal_or_string","Returns hexadecimal representation of a decimal or string value" +scalar,unhex,varchar,,"varchar","Converts hexadecimal value to string" +scalar,lcase,varchar,,"varchar","Synonym for LOWER()" +scalar,quote,varchar,,"varchar","Escapes string and adds single quotes" +scalar,soundex,varchar,,"varchar","Returns soundex string of given string" +scalar,space,varchar,,"int","Returns string of specified number of spaces" +scalar,truncate,decimal,,"decimal,int","Truncates number to specified number of decimal places" +scalar,weekday,int,,"date","Returns weekday index (0=Monday, 6=Sunday)" +scalar,yearweek,int,,"date","Returns year and week number" +scalar,dayname,varchar,,"date","Returns name of weekday" +scalar,monthname,varchar,,"date","Returns name of month" +scalar,quarter,int,,"date","Returns quarter from date (1 to 4)" +scalar,week,int,,"date","Returns week number" +aggregate,group_concat,varchar,,"any","Returns a concatenated string from a group" +aggregate,std,decimal,,"any","Returns the population standard deviation" +aggregate,variance,decimal,,"any","Returns the population variance" +aggregate,json_arrayagg,json,,"any","Aggregates result set as JSON array" +aggregate,json_objectagg,json,,"varchar,any","Aggregates result set as JSON object" diff --git 
a/ibis-server/resources/knowledge/dialects/doris.txt b/ibis-server/resources/knowledge/dialects/doris.txt new file mode 100644 index 000000000..4ed7bf6b4 --- /dev/null +++ b/ibis-server/resources/knowledge/dialects/doris.txt @@ -0,0 +1,37 @@ +### WHEN USING DORIS AS THE BACKEND DATABASE ### +Here are the tips when generating SQL queries for Wren Engine with Apache Doris as the backend database. + +- Apache Doris uses MySQL-compatible SQL syntax, but there are some differences to be aware of. +- Doris is an analytical (OLAP) database. It is optimized for large-scale aggregation queries rather than transactional workloads. +- Doris uses backtick (`) for quoting identifiers, same as MySQL. + - Example: SELECT `column_name` FROM `table_name` +- `STRING` is an alias for `VARCHAR(65533)` in Doris. +- Doris supports `LARGEINT` (128-bit integer) type which is not available in MySQL. +- Doris has versioned types: `DATEV2` (equivalent to `DATE`), `DATETIMEV2` (equivalent to `DATETIME` with fractional seconds), and `DECIMALV3` (high-precision decimal). + - Use `DATE`, `DATETIME`, and `DECIMAL` in queries — Doris handles version mapping internally. +- For date/time operations, Doris supports standard MySQL functions: + - `DATE_ADD(date, INTERVAL expr unit)` and `DATE_SUB(date, INTERVAL expr unit)` + - `DATEDIFF(expr1, expr2)` returns difference in days + - `TIMESTAMPDIFF(unit, datetime_expr1, datetime_expr2)` for precise differences + - `DATE_FORMAT(date, format)` for formatting dates + - `STR_TO_DATE(str, format)` for parsing strings to dates +- `BITMAP`, `HLL`, and `QUANTILE_STATE` are special aggregate types in Doris used for approximate distinct counting and quantile estimation. + - These types cannot be used in regular SELECT queries directly. + - Use aggregate functions like `BITMAP_UNION_COUNT`, `HLL_UNION_AGG`, `QUANTILE_PERCENT` to query them. +- Doris does not support foreign key constraints. 
Do not generate queries that rely on foreign key relationships from the database schema. + - Relationships between tables are defined in the Wren MDL (Manifest Definition Language) and are handled by the Wren Engine. +- Doris supports the `GROUP BY` clause with column aliases defined in the SELECT clause. + - Example: `SELECT column1 AS col1_alias, COUNT(*) FROM table GROUP BY col1_alias` is valid in Doris. +- Doris supports window functions with the standard `OVER (PARTITION BY ... ORDER BY ...)` syntax. +- For string concatenation, use `CONCAT()` function or `||` operator. + - Example: `CONCAT(col1, '-', col2)` or `col1 || '-' || col2` +- Doris supports `CASE WHEN ... THEN ... ELSE ... END` expressions. +- For NULL handling, use `IFNULL(expr, default)`, `COALESCE(expr1, expr2, ...)`, or `NULLIF(expr1, expr2)`. +- Doris supports `LIMIT offset, count` syntax for pagination. + - Example: `SELECT * FROM table LIMIT 10, 20` (skip 10 rows, return 20 rows) + - Also supports `LIMIT count OFFSET offset` syntax. +- `ARRAY`, `MAP`, `STRUCT`, and `VARIANT` (semi-structured JSON) types are supported in newer Doris versions. + - Use `ARRAY` functions like `ARRAY_CONTAINS`, `ARRAY_SIZE`, `EXPLODE` for array operations. + - Use `JSON_EXTRACT`, `JSON_OBJECT`, `JSON_ARRAY` for JSON/VARIANT operations. +- `DATETIME` and `TIMESTAMP` are equivalent in Doris and do NOT carry time zone information. + - If timezone-aware operations are needed, handle timezone conversion explicitly using `CONVERT_TZ(datetime, from_tz, to_tz)`. 
diff --git a/ibis-server/tests/routers/v3/connector/doris/__init__.py b/ibis-server/tests/routers/v3/connector/doris/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ibis-server/tests/routers/v3/connector/doris/conftest.py b/ibis-server/tests/routers/v3/connector/doris/conftest.py new file mode 100644 index 000000000..ea9a70542 --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/doris/conftest.py @@ -0,0 +1,136 @@ +import pathlib +import time + +import pandas as pd +import pymysql +import pytest + +from app.config import get_config +from tests.conftest import file_path + +pytestmark = pytest.mark.doris + +base_url = "/v3/connector/doris" + + +def pytest_collection_modifyitems(items): + current_file_dir = pathlib.Path(__file__).resolve().parent + for item in items: + if pathlib.Path(item.fspath).is_relative_to(current_file_dir): + item.add_marker(pytestmark) + + +function_list_path = file_path("../resources/function_list") + + +DORIS_HOST = "127.0.0.1" +DORIS_PORT = 9030 +DORIS_USER = "root" +DORIS_PASSWORD = "" +DORIS_DATABASE = "wren_test" + + +@pytest.fixture(scope="session") +def doris(request): + """Connect to Doris and load TPC-H test data.""" + conn = pymysql.connect( + host=DORIS_HOST, + port=DORIS_PORT, + user=DORIS_USER, + password=DORIS_PASSWORD, + autocommit=True, + ) + cursor = conn.cursor() + + # Create test database + cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{DORIS_DATABASE}`") + cursor.execute(f"USE `{DORIS_DATABASE}`") + + # Ensure clean state + cursor.execute("DROP TABLE IF EXISTS orders") + + # Create orders table with Doris DDL + cursor.execute( + """ + CREATE TABLE orders ( + o_orderkey INT, + o_custkey INT, + o_orderstatus VARCHAR(1), + o_totalprice DECIMAL(15, 2), + o_orderdate DATE, + o_orderpriority VARCHAR(15), + o_clerk VARCHAR(15), + o_shippriority INT, + o_comment VARCHAR(79) + ) + DISTRIBUTED BY HASH(o_orderkey) BUCKETS 1 + PROPERTIES ("replication_num" = "1") + """ + ) + + # Load test data from 
parquet + orders_pdf = pd.read_parquet(file_path("resource/tpch/data/orders.parquet")) + + # Convert date column to string for pymysql INSERT + if "o_orderdate" in orders_pdf.columns: + orders_pdf["o_orderdate"] = orders_pdf["o_orderdate"].astype(str) + + # Handle NaN values + orders_pdf = orders_pdf.where(orders_pdf.notna(), None) + + # Batch insert + columns = list(orders_pdf.columns) + col_str = ", ".join(columns) + placeholders = ", ".join(["%s"] * len(columns)) + insert_sql = f"INSERT INTO orders ({col_str}) VALUES ({placeholders})" + + data = [tuple(row) for row in orders_pdf.values] + + batch_size = 500 + for i in range(0, len(data), batch_size): + batch = data[i : i + batch_size] + cursor.executemany(insert_sql, batch) + + cursor.close() + conn.close() + + # Wait for Doris to make data visible + time.sleep(3) + + def cleanup(): + try: + c = pymysql.connect( + host=DORIS_HOST, + port=DORIS_PORT, + user=DORIS_USER, + password=DORIS_PASSWORD, + database=DORIS_DATABASE, + autocommit=True, + ) + cur = c.cursor() + cur.execute("DROP TABLE IF EXISTS orders") + cur.close() + c.close() + except Exception: + pass + + request.addfinalizer(cleanup) + + +@pytest.fixture(scope="module") +def connection_info(doris) -> dict[str, str]: + return { + "host": DORIS_HOST, + "port": str(DORIS_PORT), + "user": DORIS_USER, + "password": DORIS_PASSWORD, + "database": DORIS_DATABASE, + } + + +@pytest.fixture(autouse=True) +def set_remote_function_list_path(): + config = get_config() + config.set_remote_function_list_path(function_list_path) + yield + config.set_remote_function_list_path(None) diff --git a/ibis-server/tests/routers/v3/connector/doris/docker-compose.yml b/ibis-server/tests/routers/v3/connector/doris/docker-compose.yml new file mode 100644 index 000000000..bc19081be --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/doris/docker-compose.yml @@ -0,0 +1,18 @@ +services: + doris: + image: apache/doris:4.0.3-all-slim + container_name: doris-test + hostname: doris + 
privileged: true + ports: + - "9030:9030" # FE MySQL protocol port + - "8030:8030" # FE HTTP port (Web UI) + - "8040:8040" # BE HTTP port + environment: + - SKIP_CHECK_ULIMIT=true + healthcheck: + test: ["CMD-SHELL", "mysql -h 127.0.0.1 -P 9030 -uroot -e 'SHOW BACKENDS\\G' | grep -q 'Alive: true' || exit 1"] + interval: 15s + timeout: 10s + retries: 20 + start_period: 180s diff --git a/ibis-server/tests/routers/v3/connector/doris/test_functions.py b/ibis-server/tests/routers/v3/connector/doris/test_functions.py new file mode 100644 index 000000000..585f1f048 --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/doris/test_functions.py @@ -0,0 +1,112 @@ +import base64 + +import orjson +import pytest + +from app.config import get_config +from tests.conftest import DATAFUSION_FUNCTION_COUNT, file_path +from tests.routers.v3.connector.doris.conftest import base_url + +manifest = { + "dataSource": "doris", + "catalog": "wren", + "schema": "public", + "models": [ + { + "name": "orders", + "tableReference": { + "schema": "wren_test", + "table": "orders", + }, + "columns": [ + {"name": "o_orderkey", "type": "integer"}, + {"name": "o_totalprice", "type": "float"}, + {"name": "o_orderdate", "type": "date"}, + ], + }, + ], +} + + +function_list_path = file_path("../resources/function_list") + + +@pytest.fixture(scope="module") +def manifest_str(): + return base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + + +@pytest.fixture(autouse=True) +def set_remote_function_list_path(): + config = get_config() + config.set_remote_function_list_path(function_list_path) + yield + config.set_remote_function_list_path(None) + + +async def test_function_list(client): + config = get_config() + + config.set_remote_function_list_path(None) + response = await client.get(url=f"{base_url}/functions") + assert response.status_code == 200 + result = response.json() + assert len(result) == DATAFUSION_FUNCTION_COUNT + + config.set_remote_function_list_path(function_list_path) + response = 
await client.get(url=f"{base_url}/functions") + assert response.status_code == 200 + result = response.json() + # Doris functions from doris.csv are added to DataFusion built-ins; + # functions whose names already exist in DataFusion are deduplicated. + assert len(result) > DATAFUSION_FUNCTION_COUNT + # Verify a Doris-specific function is present + the_func = next(filter(lambda x: x["name"] == "lcase", result)) + assert the_func == { + "name": "lcase", + "description": "Synonym for LOWER()", + "function_type": "scalar", + "param_names": None, + "param_types": None, + "return_type": None, + } + + config.set_remote_function_list_path(None) + response = await client.get(url=f"{base_url}/functions") + assert response.status_code == 200 + result = response.json() + assert len(result) == DATAFUSION_FUNCTION_COUNT + + +async def test_scalar_function(client, manifest_str: str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT ABS(-1) AS col", + }, + ) + assert response.status_code == 200 + result = response.json() + assert result["columns"] == ["col"] + assert result["data"] == [[1]] + assert result["dtypes"]["col"] in ("int16", "int32") + + +async def test_aggregate_function(client, manifest_str: str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT COUNT(*) AS col FROM (SELECT 1) AS temp_table", + }, + ) + assert response.status_code == 200 + result = response.json() + assert result == { + "columns": ["col"], + "data": [[1]], + "dtypes": {"col": "int64"}, + } diff --git a/ibis-server/tests/routers/v3/connector/doris/test_metadata.py b/ibis-server/tests/routers/v3/connector/doris/test_metadata.py new file mode 100644 index 000000000..a60b82315 --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/doris/test_metadata.py @@ -0,0 
+1,77 @@ +import pytest + +from tests.routers.v3.connector.doris.conftest import base_url + +v2_base_url = "/v2/connector/doris" + + +async def test_metadata_list_tables(client, connection_info): + response = await client.post( + url=f"{v2_base_url}/metadata/tables", + json={"connectionInfo": connection_info}, + ) + assert response.status_code == 200 + + tables = response.json() + assert len(tables) > 0 + + # Find the orders table in wren_test schema + result = next( + filter(lambda x: x["name"] == "wren_test.orders", tables), + None, + ) + assert result is not None + assert result["name"] == "wren_test.orders" + assert result["properties"] is not None + assert result["properties"]["schema"] == "wren_test" + assert result["properties"]["table"] == "orders" + assert len(result["columns"]) > 0 + + # Check a specific column (o_orderkey) + orderkey_col = next( + filter(lambda col: col["name"] == "o_orderkey", result["columns"]), None + ) + assert orderkey_col is not None + assert orderkey_col["type"] in ["INTEGER", "INT"] + assert orderkey_col["nestedColumns"] is None + + # Check decimal column (o_totalprice) + price_col = next( + filter(lambda col: col["name"] == "o_totalprice", result["columns"]), None + ) + assert price_col is not None + assert price_col["type"] in ["DECIMAL", "DECIMALV3"] + + # Check date column (o_orderdate) + date_col = next( + filter(lambda col: col["name"] == "o_orderdate", result["columns"]), None + ) + assert date_col is not None + assert date_col["type"] in ["DATE", "DATEV2"] + + +async def test_metadata_list_constraints(client, connection_info): + response = await client.post( + url=f"{v2_base_url}/metadata/constraints", + json={"connectionInfo": connection_info}, + ) + assert response.status_code == 200 + + constraints = response.json() + + # Doris does not support foreign key constraints, expect empty list + assert constraints == [] + + +async def test_metadata_db_version(client, connection_info): + response = await client.post( + 
url=f"{v2_base_url}/metadata/version", + json={"connectionInfo": connection_info}, + ) + assert response.status_code == 200 + assert response.text is not None + + # Doris returns MySQL-compatible version string (e.g. "5.7.99") + version_str = response.text + assert version_str is not None + assert len(version_str) > 0 diff --git a/ibis-server/tests/routers/v3/connector/doris/test_query.py b/ibis-server/tests/routers/v3/connector/doris/test_query.py new file mode 100644 index 000000000..4c1a6b173 --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/doris/test_query.py @@ -0,0 +1,202 @@ +import base64 + +import orjson +import pytest + +from tests.routers.v3.connector.doris.conftest import base_url + +manifest = { + "dataSource": "doris", + "catalog": "wren", + "schema": "public", + "models": [ + { + "name": "orders", + "tableReference": { + "schema": "wren_test", + "table": "orders", + }, + "columns": [ + {"name": "orderkey", "expression": "o_orderkey", "type": "integer"}, + {"name": "custkey", "expression": "o_custkey", "type": "integer"}, + { + "name": "orderstatus", + "expression": "o_orderstatus", + "type": "varchar", + }, + {"name": "totalprice", "expression": "o_totalprice", "type": "float"}, + {"name": "orderdate", "expression": "o_orderdate", "type": "date"}, + { + "name": "order_cust_key", + "expression": "concat(o_orderkey, '_', o_custkey)", + "type": "varchar", + }, + { + "name": "timestamp", + "expression": "cast('2024-01-01 23:59:59' as datetime)", + "type": "timestamp", + }, + { + "name": "timestamptz", + "expression": "cast('2024-01-01 23:59:59' as datetime)", + "type": "timestamp", + }, + { + "name": "test_null_time", + "expression": "cast(NULL as datetime)", + "type": "timestamp", + }, + ], + }, + ], +} + + +@pytest.fixture(scope="module") +def manifest_str(): + return base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + + +async def test_query(client, manifest_str, connection_info): + response = await client.post( + 
url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM wren.public.orders ORDER BY orderkey LIMIT 1", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + assert result["data"][0][0] == 1 # orderkey + assert result["data"][0][1] == 370 # custkey + assert result["data"][0][2] == "O" # orderstatus + + +async def test_query_with_limit(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + params={"limit": 1}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM wren.public.orders", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + + response = await client.post( + url=f"{base_url}/query", + params={"limit": 1}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM wren.public.orders LIMIT 10", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + + +async def test_query_with_invalid_manifest_str(client, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": "xxx", + "sql": "SELECT * FROM wren.public.orders LIMIT 1", + }, + ) + assert response.status_code == 422 + + +async def test_query_without_manifest(client, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "sql": "SELECT * FROM wren.public.orders LIMIT 1", + }, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "manifestStr"] + assert result["detail"][0]["msg"] == "Field required" + + +async def 
test_query_without_sql(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={"connectionInfo": connection_info, "manifestStr": manifest_str}, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "sql"] + assert result["detail"][0]["msg"] == "Field required" + + +async def test_query_with_dry_run(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + params={"dryRun": True}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM wren.public.orders LIMIT 1", + }, + ) + assert response.status_code == 204 + + +async def test_query_with_dry_run_and_invalid_sql( + client, manifest_str, connection_info +): + response = await client.post( + url=f"{base_url}/query", + params={"dryRun": True}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM X", + }, + ) + assert response.status_code == 422 + assert response.text is not None + + +async def test_no_transaction_wrapping(client, connection_info): + """Verify Doris queries execute without BEGIN/ROLLBACK wrapping. + + Doris is an OLAP engine and rejects transactional SELECT statements. + The autocommit fix must be in place. + """ + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": base64.b64encode( + orjson.dumps( + { + "dataSource": "doris", + "catalog": "c", + "schema": "s", + "models": [], + } + ) + ).decode("utf-8"), + "sql": "SELECT version()", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + # Doris returns MySQL-compatible version (e.g. 
"5.7.99") + version_str = str(result["data"][0][0]) + assert len(version_str) > 0 diff --git a/ibis-server/tools/query_local_run.py b/ibis-server/tools/query_local_run.py index 7b80516c4..dec5fae47 100644 --- a/ibis-server/tools/query_local_run.py +++ b/ibis-server/tools/query_local_run.py @@ -15,7 +15,7 @@ import json import os from app.custom_sqlglot.dialects.wren import Wren -from app.model import BigQueryDatasetConnectionInfo, MSSqlConnectionInfo, MySqlConnectionInfo, OracleConnectionInfo, PostgresConnectionInfo, SnowflakeConnectionInfo +from app.model import BigQueryDatasetConnectionInfo, DorisConnectionInfo, MSSqlConnectionInfo, MySqlConnectionInfo, OracleConnectionInfo, PostgresConnectionInfo, SnowflakeConnectionInfo from app.model.connector import BigQueryConnector from app.util import to_json import sqlglot @@ -96,6 +96,9 @@ elif data_source == "mysql": connection_info = MySqlConnectionInfo.model_validate_json(json.dumps(connection_info)) connection = DataSourceExtension.get_mysql_connection(connection_info) +elif data_source == "doris": + connection_info = DorisConnectionInfo.model_validate_json(json.dumps(connection_info)) + connection = DataSourceExtension.get_doris_connection(connection_info) elif data_source == "postgres": connection_info = PostgresConnectionInfo.model_validate_json(json.dumps(connection_info)) connection = DataSourceExtension.get_postgres_connection(connection_info) diff --git a/wren-core-base/manifest-macro/src/lib.rs b/wren-core-base/manifest-macro/src/lib.rs index 51e94c5a9..5efcecbf1 100644 --- a/wren-core-base/manifest-macro/src/lib.rs +++ b/wren-core-base/manifest-macro/src/lib.rs @@ -85,6 +85,8 @@ pub fn data_source(python_binding: proc_macro::TokenStream) -> proc_macro::Token MSSQL, #[serde(alias = "mysql")] MySQL, + #[serde(alias = "doris")] + Doris, #[serde(alias = "postgres")] Postgres, #[serde(alias = "snowflake")] diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index 
beb841b04..fed827a7c 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -136,6 +136,7 @@ impl Display for DataSource { DataSource::Trino => write!(f, "TRINO"), DataSource::MSSQL => write!(f, "MSSQL"), DataSource::MySQL => write!(f, "MYSQL"), + DataSource::Doris => write!(f, "DORIS"), DataSource::Postgres => write!(f, "POSTGRES"), DataSource::Snowflake => write!(f, "SNOWFLAKE"), DataSource::Datafusion => write!(f, "DATAFUSION"), @@ -164,6 +165,7 @@ impl FromStr for DataSource { "TRINO" => Ok(DataSource::Trino), "MSSQL" => Ok(DataSource::MSSQL), "MYSQL" => Ok(DataSource::MySQL), + "DORIS" => Ok(DataSource::Doris), "POSTGRES" => Ok(DataSource::Postgres), "SNOWFLAKE" => Ok(DataSource::Snowflake), "DATAFUSION" => Ok(DataSource::Datafusion), diff --git a/wren-core/core/src/mdl/dialect/inner_dialect.rs b/wren-core/core/src/mdl/dialect/inner_dialect.rs index 9fae7167f..e7e90961d 100644 --- a/wren-core/core/src/mdl/dialect/inner_dialect.rs +++ b/wren-core/core/src/mdl/dialect/inner_dialect.rs @@ -121,6 +121,7 @@ pub trait InnerDialect: Send + Sync { pub fn get_inner_dialect(data_source: &DataSource) -> Box { match data_source { DataSource::MySQL => Box::new(MySQLDialect {}), + DataSource::Doris => Box::new(MySQLDialect {}), DataSource::BigQuery => Box::new(BigQueryDialect {}), DataSource::Oracle => Box::new(OracleDialect {}), DataSource::MSSQL => Box::new(MsSqlDialect {}),