Skip to content

Commit 8184d59

Browse files
authored
feat: BI-6553 Add Trino timeouts (#1221)
1 parent 5ae45da commit 8184d59

File tree

5 files changed

+58
-33
lines changed

5 files changed

+58
-33
lines changed

lib/dl_connector_trino/dl_connector_trino/core/adapters.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,19 @@
6969

7070

7171
class CustomHTTPAdapter(HTTPAdapter):
72-
def __init__(self, ssl_ca: str, *args: Any, **kwargs: Any) -> None:
72+
"""
73+
This custom adapter is here to create an SSL context with a custom CA certificate provided as a string instead of a file path.
74+
"""
75+
76+
def __init__(self, ssl_ca: str | None = None, *args: Any, **kwargs: Any) -> None:
7377
self.ssl_ca = ssl_ca
7478
super().__init__(*args, **kwargs)
7579

7680
def init_poolmanager(self, connections: int, maxsize: int, block: bool = False, **pool_kwargs: Any) -> None:
77-
# Use a secure context with the provided SSL CA
78-
context = ssl.create_default_context(cadata=self.ssl_ca)
81+
if self.ssl_ca is None:
82+
context = ssl.create_default_context(capath=get_root_certificates_path())
83+
else:
84+
context = ssl.create_default_context(cadata=self.ssl_ca)
7985
super().init_poolmanager(connections, maxsize, block, ssl_context=context, **pool_kwargs)
8086

8187

@@ -122,10 +128,6 @@ class TrinoDefaultAdapter(BaseClassicAdapter[TrinoConnTargetDTO]):
122128
EXTRA_EXC_CLS = (sa_exc.DBAPIError,)
123129

124130
def get_conn_line(self, db_name: str | None = None, params: dict[str, Any] | None = None) -> str:
125-
# We do not expect to transfer any additional parameters when creating the engine.
126-
# This check is needed to track if it still passed.
127-
assert params is None
128-
129131
params = params or {}
130132
return trino_url(
131133
host=self._target_dto.host,
@@ -136,11 +138,24 @@ def get_conn_line(self, db_name: str | None = None, params: dict[str, Any] | Non
136138
**params,
137139
)
138140

141+
def _get_http_session(self) -> requests.Session:
142+
session = requests.Session()
143+
adapter = CustomHTTPAdapter(ssl_ca=self._target_dto.ssl_ca)
144+
session.mount("http://", adapter)
145+
session.mount("https://", adapter)
146+
return session
147+
139148
def get_connect_args(self) -> dict[str, Any]:
140-
args: dict[str, Any] = {
141-
**super().get_connect_args(),
142-
"legacy_primitive_types": True,
143-
}
149+
timeout = (
150+
self._target_dto.connect_timeout,
151+
self._target_dto.total_timeout,
152+
)
153+
args: dict[str, Any] = super().get_connect_args() | dict(
154+
http_scheme="https" if self._target_dto.ssl_enable else "http",
155+
http_session=self._get_http_session(),
156+
legacy_primitive_types=True,
157+
request_timeout=timeout,
158+
)
144159
if self._target_dto.auth_type is TrinoAuthType.none:
145160
pass
146161
elif self._target_dto.auth_type is TrinoAuthType.password:
@@ -150,18 +165,6 @@ def get_connect_args(self) -> dict[str, Any]:
150165
else:
151166
raise NotImplementedError(f"{self._target_dto.auth_type.name} authentication is not supported yet")
152167

153-
if not self._target_dto.ssl_enable:
154-
args["http_scheme"] = "http"
155-
return args
156-
157-
args["http_scheme"] = "https"
158-
if self._target_dto.ssl_ca:
159-
session = requests.Session()
160-
session.mount("https://", CustomHTTPAdapter(self._target_dto.ssl_ca))
161-
args["http_session"] = session
162-
else:
163-
args["verify"] = get_root_certificates_path()
164-
165168
return args
166169

167170
def execute_by_steps(self, db_adapter_query: DBAdapterQuery) -> Generator[ExecutionStep, None, None]:
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import attr
2+
3+
from dl_core.connection_models.conn_options import ConnectOptions
4+
5+
6+
@attr.s(frozen=True, hash=True)
7+
class TrinoConnectOptions(ConnectOptions):
8+
connect_timeout: int | None = attr.ib(default=None)
9+
total_timeout: int | None = attr.ib(default=None)

lib/dl_connector_trino/dl_connector_trino/core/connection_executors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from dl_core.connection_executors.async_sa_executors import DefaultSqlAlchemyConnExecutor
44

55
from dl_connector_trino.core.adapters import TrinoDefaultAdapter
6+
from dl_connector_trino.core.conn_options import TrinoConnectOptions
67
from dl_connector_trino.core.dto import TrinoConnDTO
78
from dl_connector_trino.core.target_dto import TrinoConnTargetDTO
89

@@ -11,6 +12,7 @@
1112
class TrinoConnExecutor(DefaultSqlAlchemyConnExecutor[TrinoDefaultAdapter]):
1213
TARGET_ADAPTER_CLS = TrinoDefaultAdapter
1314
_conn_dto: TrinoConnDTO = attr.ib()
15+
_conn_options: TrinoConnectOptions = attr.ib()
1416

1517
async def _make_target_conn_dto_pool(self) -> list[TrinoConnTargetDTO]:
1618
return [
@@ -26,5 +28,7 @@ async def _make_target_conn_dto_pool(self) -> list[TrinoConnTargetDTO]:
2628
jwt=self._conn_dto.jwt,
2729
ssl_enable=self._conn_dto.ssl_enable,
2830
ssl_ca=self._conn_dto.ssl_ca,
31+
connect_timeout=self._conn_options.connect_timeout,
32+
total_timeout=self._conn_options.total_timeout,
2933
)
3034
]

lib/dl_connector_trino/dl_connector_trino/core/target_dto.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Optional
2-
31
import attr
42
from typing_extensions import Self
53

@@ -9,18 +7,20 @@
97
from dl_connector_trino.core.constants import TrinoAuthType
108

119

12-
@attr.s(frozen=True)
10+
@attr.s(frozen=True, kw_only=True)
1311
class TrinoConnTargetDTO(ConnTargetDTO):
1412
host: str = attr.ib()
1513
port: int = attr.ib()
1614
username: str = attr.ib()
17-
auth_type: TrinoAuthType = attr.ib(kw_only=True)
18-
password: Optional[str] = attr.ib(repr=False, kw_only=True, default=None)
19-
jwt: Optional[str] = attr.ib(repr=False, kw_only=True, default=None)
20-
ssl_enable: bool = attr.ib(kw_only=True, default=False)
21-
ssl_ca: Optional[str] = attr.ib(kw_only=True, default=None)
22-
23-
def get_effective_host(self) -> Optional[str]:
15+
auth_type: TrinoAuthType = attr.ib()
16+
password: str | None = attr.ib(repr=False, default=None)
17+
jwt: str | None = attr.ib(repr=False, default=None)
18+
ssl_enable: bool = attr.ib(default=False)
19+
ssl_ca: str | None = attr.ib(default=None)
20+
connect_timeout: int | None = attr.ib(default=None)
21+
total_timeout: int | None = attr.ib(default=None)
22+
23+
def get_effective_host(self) -> str | None:
2424
return self.host
2525

2626
@classmethod

lib/dl_connector_trino/dl_connector_trino/core/us_connection.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from dl_i18n.localizer_base import Localizer
2222

2323
from dl_connector_trino.api.i18n.localizer import Translatable
24+
from dl_connector_trino.core.conn_options import TrinoConnectOptions
2425
from dl_connector_trino.core.constants import (
2526
CONNECTION_TYPE_TRINO,
2627
SOURCE_TYPE_TRINO_SUBSELECT,
@@ -81,6 +82,14 @@ class DataModel(ConnectionSQL.DataModel):
8182
jwt: str | None = attr.ib(repr=secrepr, default=None)
8283
listing_sources: ListingSources = attr.ib()
8384

85+
def get_conn_options(self) -> TrinoConnectOptions:
86+
base = super().get_conn_options()
87+
return base.to_subclass(
88+
TrinoConnectOptions,
89+
connect_timeout=1,
90+
total_timeout=80,
91+
)
92+
8493
def get_data_source_template_templates(self, localizer: Localizer) -> list[DataSourceTemplate]:
8594
result: list[DataSourceTemplate] = []
8695

0 commit comments

Comments
 (0)