Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 64 additions & 13 deletions src/postgrest/src/postgrest/_async/request_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Generic, Optional, TypeVar, Union
from typing import Any, Generic, Literal, Optional, TypeVar, Union, overload

from httpx import AsyncClient, BasicAuth, Headers, QueryParams, Response
from pydantic import ValidationError
Expand Down Expand Up @@ -32,7 +32,7 @@ class AsyncQueryRequestBuilder:
def __init__(self, request: ReqConfig):
self.request = request

async def execute(self) -> APIResponse | str:
async def execute(self) -> APIResponse:
"""Execute the query.

.. tip::
Expand All @@ -47,17 +47,6 @@ async def execute(self) -> APIResponse | str:
r = await self.request.send()
try:
if r.is_success:
if self.request.http_method != "HEAD":
body = r.text
if self.request.headers.get("Accept") == "text/csv":
return body
if self.request.headers.get(
"Accept"
) and "application/vnd.pgrst.plan" in self.request.headers.get(
"Accept"
):
if "+json" not in self.request.headers.get("Accept"):
return body
return APIResponse.from_http_request_response(r)
else:
json_obj = model_validate_json(APIErrorFromJSON, r.content)
Expand Down Expand Up @@ -95,6 +84,22 @@ async def execute(self) -> SingleAPIResponse:
raise APIError(generate_default_error_message(r))


class AsyncExplainRequestBuilder:
def __init__(self, request: ReqConfig):
self.request = request

async def execute(self) -> str:
r = await self.request.send()
try:
if r.is_success:
return r.text
else:
json_obj = model_validate_json(APIErrorFromJSON, r.content)
raise APIError(dict(json_obj))
except ValidationError as e:
raise APIError(generate_default_error_message(r))


class AsyncMaybeSingleRequestBuilder:
def __init__(self, request: ReqConfig):
self.request = request
Expand Down Expand Up @@ -176,6 +181,52 @@ def csv(self) -> AsyncSingleRequestBuilder:
self.request.headers["Accept"] = "text/csv"
return AsyncSingleRequestBuilder(self.request)

@overload
def explain(
self,
analyze: bool = False,
verbose: bool = False,
settings: bool = False,
buffers: bool = False,
wal: bool = False,
format: Literal["text"] = "text",
) -> AsyncExplainRequestBuilder: ...

@overload
def explain(
self,
analyze: bool = False,
verbose: bool = False,
settings: bool = False,
buffers: bool = False,
wal: bool = False,
*,
format: Literal["json"],
) -> AsyncSingleRequestBuilder: ...

def explain(
self,
analyze: bool = False,
verbose: bool = False,
settings: bool = False,
buffers: bool = False,
wal: bool = False,
format: Literal["text", "json"] = "text",
) -> AsyncExplainRequestBuilder | AsyncSingleRequestBuilder:
options = [
key
for key, value in locals().items()
if key not in ["self", "format"] and value
]
options_str = "|".join(options)
self.request.headers["Accept"] = (
f"application/vnd.pgrst.plan+{format}; options={options_str}"
)
if format == "text":
return AsyncExplainRequestBuilder(self.request)
else:
return AsyncSingleRequestBuilder(self.request)


class AsyncRequestBuilder: #
def __init__(
Expand Down
77 changes: 64 additions & 13 deletions src/postgrest/src/postgrest/_sync/request_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Generic, Optional, TypeVar, Union
from typing import Any, Generic, Literal, Optional, TypeVar, Union, overload

from httpx import BasicAuth, Client, Headers, QueryParams, Response
from pydantic import ValidationError
Expand Down Expand Up @@ -32,7 +32,7 @@ class SyncQueryRequestBuilder:
def __init__(self, request: ReqConfig):
self.request = request

def execute(self) -> APIResponse | str:
def execute(self) -> APIResponse:
"""Execute the query.

.. tip::
Expand All @@ -47,17 +47,6 @@ def execute(self) -> APIResponse | str:
r = self.request.send()
try:
if r.is_success:
if self.request.http_method != "HEAD":
body = r.text
if self.request.headers.get("Accept") == "text/csv":
return body
if self.request.headers.get(
"Accept"
) and "application/vnd.pgrst.plan" in self.request.headers.get(
"Accept"
):
if "+json" not in self.request.headers.get("Accept"):
return body
return APIResponse.from_http_request_response(r)
else:
json_obj = model_validate_json(APIErrorFromJSON, r.content)
Expand Down Expand Up @@ -95,6 +84,22 @@ def execute(self) -> SingleAPIResponse:
raise APIError(generate_default_error_message(r))


class SyncExplainRequestBuilder:
def __init__(self, request: ReqConfig):
self.request = request

def execute(self) -> str:
r = self.request.send()
try:
if r.is_success:
return r.text
else:
json_obj = model_validate_json(APIErrorFromJSON, r.content)
raise APIError(dict(json_obj))
except ValidationError as e:
raise APIError(generate_default_error_message(r))


class SyncMaybeSingleRequestBuilder:
def __init__(self, request: ReqConfig):
self.request = request
Expand Down Expand Up @@ -176,6 +181,52 @@ def csv(self) -> SyncSingleRequestBuilder:
self.request.headers["Accept"] = "text/csv"
return SyncSingleRequestBuilder(self.request)

@overload
def explain(
self,
analyze: bool = False,
verbose: bool = False,
settings: bool = False,
buffers: bool = False,
wal: bool = False,
format: Literal["text"] = "text",
) -> SyncExplainRequestBuilder: ...

@overload
def explain(
self,
analyze: bool = False,
verbose: bool = False,
settings: bool = False,
buffers: bool = False,
wal: bool = False,
*,
format: Literal["json"],
) -> SyncSingleRequestBuilder: ...

def explain(
self,
analyze: bool = False,
verbose: bool = False,
settings: bool = False,
buffers: bool = False,
wal: bool = False,
format: Literal["text", "json"] = "text",
) -> SyncExplainRequestBuilder | SyncSingleRequestBuilder:
options = [
key
for key, value in locals().items()
if key not in ["self", "format"] and value
]
options_str = "|".join(options)
self.request.headers["Accept"] = (
f"application/vnd.pgrst.plan+{format}; options={options_str}"
)
if format == "text":
return SyncExplainRequestBuilder(self.request)
else:
return SyncSingleRequestBuilder(self.request)


class SyncRequestBuilder: #
def __init__(
Expand Down
20 changes: 0 additions & 20 deletions src/postgrest/src/postgrest/base_request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,26 +557,6 @@ def max_affected(self: Self, value: int) -> Self:


class BaseSelectRequestBuilder(BaseFilterRequestBuilder[C]):
def explain(
self: Self,
analyze: bool = False,
verbose: bool = False,
settings: bool = False,
buffers: bool = False,
wal: bool = False,
format: Literal["text", "json"] = "text",
) -> Self:
options = [
key
for key, value in locals().items()
if key not in ["self", "format"] and value
]
options_str = "|".join(options)
self.request.headers["Accept"] = (
f"application/vnd.pgrst.plan+{format}; options={options_str}"
)
return self

def order(
self: Self,
column: str,
Expand Down