diff --git a/src/postgrest/src/postgrest/_async/request_builder.py b/src/postgrest/src/postgrest/_async/request_builder.py index a575edd7..3c996922 100644 --- a/src/postgrest/src/postgrest/_async/request_builder.py +++ b/src/postgrest/src/postgrest/_async/request_builder.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Generic, Optional, TypeVar, Union +from typing import Any, Generic, Literal, Optional, TypeVar, Union, overload from httpx import AsyncClient, BasicAuth, Headers, QueryParams, Response from pydantic import ValidationError @@ -32,7 +32,7 @@ class AsyncQueryRequestBuilder: def __init__(self, request: ReqConfig): self.request = request - async def execute(self) -> APIResponse | str: + async def execute(self) -> APIResponse: """Execute the query. .. tip:: @@ -47,17 +47,6 @@ async def execute(self) -> APIResponse | str: r = await self.request.send() try: if r.is_success: - if self.request.http_method != "HEAD": - body = r.text - if self.request.headers.get("Accept") == "text/csv": - return body - if self.request.headers.get( - "Accept" - ) and "application/vnd.pgrst.plan" in self.request.headers.get( - "Accept" - ): - if "+json" not in self.request.headers.get("Accept"): - return body return APIResponse.from_http_request_response(r) else: json_obj = model_validate_json(APIErrorFromJSON, r.content) @@ -95,6 +84,22 @@ async def execute(self) -> SingleAPIResponse: raise APIError(generate_default_error_message(r)) +class AsyncExplainRequestBuilder: + def __init__(self, request: ReqConfig): + self.request = request + + async def execute(self) -> str: + r = await self.request.send() + try: + if r.is_success: + return r.text + else: + json_obj = model_validate_json(APIErrorFromJSON, r.content) + raise APIError(dict(json_obj)) + except ValidationError as e: + raise APIError(generate_default_error_message(r)) + + class AsyncMaybeSingleRequestBuilder: def __init__(self, request: ReqConfig): self.request = request @@ -176,6 +181,52 @@ def csv(self) -> AsyncSingleRequestBuilder: self.request.headers["Accept"] = "text/csv" return AsyncSingleRequestBuilder(self.request) + @overload + def explain( + self, + analyze: bool = False, + verbose: bool = False, + settings: bool = False, + buffers: bool = False, + wal: bool = False, + format: Literal["text"] = "text", + ) -> AsyncExplainRequestBuilder: ... + + @overload + def explain( + self, + analyze: bool = False, + verbose: bool = False, + settings: bool = False, + buffers: bool = False, + wal: bool = False, + *, + format: Literal["json"], + ) -> AsyncSingleRequestBuilder: ... + + def explain( + self, + analyze: bool = False, + verbose: bool = False, + settings: bool = False, + buffers: bool = False, + wal: bool = False, + format: Literal["text", "json"] = "text", + ) -> AsyncExplainRequestBuilder | AsyncSingleRequestBuilder: + options = [ + key + for key, value in locals().items() + if key not in ["self", "format"] and value + ] + options_str = "|".join(options) + self.request.headers["Accept"] = ( + f"application/vnd.pgrst.plan+{format}; options={options_str}" + ) + if format == "text": + return AsyncExplainRequestBuilder(self.request) + else: + return AsyncSingleRequestBuilder(self.request) + class AsyncRequestBuilder: # def __init__( diff --git a/src/postgrest/src/postgrest/_sync/request_builder.py b/src/postgrest/src/postgrest/_sync/request_builder.py index e7e2bd9e..a5340403 100644 --- a/src/postgrest/src/postgrest/_sync/request_builder.py +++ b/src/postgrest/src/postgrest/_sync/request_builder.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Generic, Optional, TypeVar, Union +from typing import Any, Generic, Literal, Optional, TypeVar, Union, overload from httpx import BasicAuth, Client, Headers, QueryParams, Response from pydantic import ValidationError @@ -32,7 +32,7 @@ class SyncQueryRequestBuilder: def __init__(self, request: ReqConfig): self.request = request - def execute(self) -> APIResponse | str: + def execute(self) -> APIResponse: """Execute the query. .. tip:: @@ -47,17 +47,6 @@ def execute(self) -> APIResponse | str: r = self.request.send() try: if r.is_success: - if self.request.http_method != "HEAD": - body = r.text - if self.request.headers.get("Accept") == "text/csv": - return body - if self.request.headers.get( - "Accept" - ) and "application/vnd.pgrst.plan" in self.request.headers.get( - "Accept" - ): - if "+json" not in self.request.headers.get("Accept"): - return body return APIResponse.from_http_request_response(r) else: json_obj = model_validate_json(APIErrorFromJSON, r.content) @@ -95,6 +84,22 @@ def execute(self) -> SingleAPIResponse: raise APIError(generate_default_error_message(r)) +class SyncExplainRequestBuilder: + def __init__(self, request: ReqConfig): + self.request = request + + def execute(self) -> str: + r = self.request.send() + try: + if r.is_success: + return r.text + else: + json_obj = model_validate_json(APIErrorFromJSON, r.content) + raise APIError(dict(json_obj)) + except ValidationError as e: + raise APIError(generate_default_error_message(r)) + + class SyncMaybeSingleRequestBuilder: def __init__(self, request: ReqConfig): self.request = request @@ -176,6 +181,52 @@ def csv(self) -> SyncSingleRequestBuilder: self.request.headers["Accept"] = "text/csv" return SyncSingleRequestBuilder(self.request) + @overload + def explain( + self, + analyze: bool = False, + verbose: bool = False, + settings: bool = False, + buffers: bool = False, + wal: bool = False, + format: Literal["text"] = "text", + ) -> SyncExplainRequestBuilder: ... + + @overload + def explain( + self, + analyze: bool = False, + verbose: bool = False, + settings: bool = False, + buffers: bool = False, + wal: bool = False, + *, + format: Literal["json"], + ) -> SyncSingleRequestBuilder: ... + + def explain( + self, + analyze: bool = False, + verbose: bool = False, + settings: bool = False, + buffers: bool = False, + wal: bool = False, + format: Literal["text", "json"] = "text", + ) -> SyncExplainRequestBuilder | SyncSingleRequestBuilder: + options = [ + key + for key, value in locals().items() + if key not in ["self", "format"] and value + ] + options_str = "|".join(options) + self.request.headers["Accept"] = ( + f"application/vnd.pgrst.plan+{format}; options={options_str}" + ) + if format == "text": + return SyncExplainRequestBuilder(self.request) + else: + return SyncSingleRequestBuilder(self.request) + class SyncRequestBuilder: # def __init__( diff --git a/src/postgrest/src/postgrest/base_request_builder.py b/src/postgrest/src/postgrest/base_request_builder.py index f92a6158..49e84005 100644 --- a/src/postgrest/src/postgrest/base_request_builder.py +++ b/src/postgrest/src/postgrest/base_request_builder.py @@ -557,26 +557,6 @@ def max_affected(self: Self, value: int) -> Self: class BaseSelectRequestBuilder(BaseFilterRequestBuilder[C]): - def explain( - self: Self, - analyze: bool = False, - verbose: bool = False, - settings: bool = False, - buffers: bool = False, - wal: bool = False, - format: Literal["text", "json"] = "text", - ) -> Self: - options = [ - key - for key, value in locals().items() - if key not in ["self", "format"] and value - ] - options_str = "|".join(options) - self.request.headers["Accept"] = ( - f"application/vnd.pgrst.plan+{format}; options={options_str}" - ) - return self - def order( self: Self, column: str,