diff --git a/src/core.c b/src/core.c index cae7dc1..77528ef 100644 --- a/src/core.c +++ b/src/core.c @@ -9,13 +9,6 @@ #include "event.h" #include "errors.h" -typedef struct { - int64 id; - StringInfo body; - struct curl_slist* request_headers; - int32 timeout_milliseconds; -} CurlData; - static SPIPlanPtr del_response_plan = NULL; static SPIPlanPtr del_return_queue_plan = NULL; static SPIPlanPtr ins_response_plan = NULL; @@ -23,9 +16,9 @@ static SPIPlanPtr ins_response_plan = NULL; static size_t body_cb(void *contents, size_t size, size_t nmemb, void *userp) { - CurlData *cdata = (CurlData*) userp; + CurlHandle *handle = (CurlHandle*) userp; size_t realsize = size * nmemb; - appendBinaryStringInfo(cdata->body, (const char*)contents, (int)realsize); + appendBinaryStringInfo(handle->body, (const char*)contents, (int)realsize); return realsize; } @@ -52,85 +45,73 @@ static struct curl_slist *pg_text_array_to_slist(ArrayType *array, return headers; } -// We need a different memory context here, as the parent function will have an SPI memory context, which has a shorter lifetime. -static void init_curl_handle(CURLM *curl_mhandle, MemoryContext curl_memctx, int64 id, Datum urlBin, NullableDatum bodyBin, NullableDatum headersBin, Datum methodBin, int32 timeout_milliseconds){ - MemoryContext old_ctx = MemoryContextSwitchTo(curl_memctx); - - CurlData *cdata = palloc(sizeof(CurlData)); - cdata->id = id; - cdata->body = makeStringInfo(); +void init_curl_handle(CurlHandle *handle, RequestQueueRow row){ + handle->id = row.id; + handle->body = makeStringInfo(); + handle->ez_handle = curl_easy_init(); - cdata->timeout_milliseconds = timeout_milliseconds; + handle->timeout_milliseconds = row.timeout_milliseconds; - if (!headersBin.isnull) { - ArrayType *pgHeaders = DatumGetArrayTypeP(headersBin.value); + if (!row.headersBin.isnull) { + ArrayType *pgHeaders = DatumGetArrayTypeP(row.headersBin.value); struct curl_slist *request_headers = NULL; request_headers = pg_text_array_to_slist(pgHeaders, request_headers); EREPORT_CURL_SLIST_APPEND(request_headers, "User-Agent: pg_net/" EXTVERSION); - cdata->request_headers = request_headers; + handle->request_headers = request_headers; } - char *url = TextDatumGetCString(urlBin); + handle->url = TextDatumGetCString(row.url); - char *reqBody = !bodyBin.isnull ? TextDatumGetCString(bodyBin.value) : NULL; + handle->req_body = !row.bodyBin.isnull ? TextDatumGetCString(row.bodyBin.value) : NULL; - char *method = TextDatumGetCString(methodBin); - if (strcasecmp(method, "GET") != 0 && strcasecmp(method, "POST") != 0 && strcasecmp(method, "DELETE") != 0) { - ereport(ERROR, errmsg("Unsupported request method %s", method)); - } + handle->method = TextDatumGetCString(row.method); - CURL *curl_ez_handle = curl_easy_init(); - if(!curl_ez_handle) - ereport(ERROR, errmsg("curl_easy_init()")); + if (strcasecmp(handle->method, "GET") != 0 && strcasecmp(handle->method, "POST") != 0 && strcasecmp(handle->method, "DELETE") != 0) { + ereport(ERROR, errmsg("Unsupported request method %s", handle->method)); + } - if (strcasecmp(method, "GET") == 0) { - if (reqBody) { - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_POSTFIELDS, reqBody); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_CUSTOMREQUEST, "GET"); + if (strcasecmp(handle->method, "GET") == 0) { + if (handle->req_body) { + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_POSTFIELDS, handle->req_body); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_CUSTOMREQUEST, "GET"); } } - if (strcasecmp(method, "POST") == 0) { - if (reqBody) { - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_POSTFIELDS, reqBody); + if (strcasecmp(handle->method, "POST") == 0) { + if (handle->req_body) { + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_POSTFIELDS, handle->req_body); } else { - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_POST, 1L); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_POSTFIELDSIZE, 0L); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_POST, 1L); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_POSTFIELDSIZE, 0L); } } - if (strcasecmp(method, "DELETE") == 0) { - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_CUSTOMREQUEST, "DELETE"); - if (reqBody) { - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_POSTFIELDS, reqBody); + if (strcasecmp(handle->method, "DELETE") == 0) { + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_CUSTOMREQUEST, "DELETE"); + if (handle->req_body) { + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_POSTFIELDS, handle->req_body); } } - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_WRITEFUNCTION, body_cb); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_WRITEDATA, cdata); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_HEADER, 0L); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_URL, url); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_HTTPHEADER, cdata->request_headers); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_TIMEOUT_MS, (long) cdata->timeout_milliseconds); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_PRIVATE, cdata); - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_FOLLOWLOCATION, (long) true); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_WRITEFUNCTION, body_cb); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_WRITEDATA, handle); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_HEADER, 0L); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_URL, handle->url); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_HTTPHEADER, handle->request_headers); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_TIMEOUT_MS, (long) handle->timeout_milliseconds); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_PRIVATE, handle); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_FOLLOWLOCATION, (long) true); if (log_min_messages <= DEBUG2) - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_VERBOSE, 1L); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_VERBOSE, 1L); #if LIBCURL_VERSION_NUM >= 0x075500 /* libcurl 7.85.0 */ - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_PROTOCOLS_STR, "http,https"); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_PROTOCOLS_STR, "http,https"); #else - EREPORT_CURL_SETOPT(curl_ez_handle, CURLOPT_PROTOCOLS, CURLPROTO_HTTP | CURLPROTO_HTTPS); + EREPORT_CURL_SETOPT(handle->ez_handle, CURLOPT_PROTOCOLS, CURLPROTO_HTTP | CURLPROTO_HTTPS); #endif - - EREPORT_MULTI( - curl_multi_add_handle(curl_mhandle, curl_ez_handle) - ); - - MemoryContextSwitchTo(old_ctx); } void set_curl_mhandle(WorkerState *wstate){ @@ -141,8 +122,6 @@ void set_curl_mhandle(WorkerState *wstate){ } uint64 delete_expired_responses(char *ttl, int batch_size){ - SPI_connect(); - if (del_response_plan == NULL) { SPIPlanPtr tmp = SPI_prepare("\ WITH\ @@ -178,14 +157,10 @@ uint64 delete_expired_responses(char *ttl, int batch_size){ ereport(ERROR, errmsg("Error expiring response table rows: %s", SPI_result_code_string(ret_code))); } - SPI_finish(); - return affected_rows; } -uint64 consume_request_queue(CURLM *curl_mhandle, int batch_size, MemoryContext curl_memctx){ - SPI_connect(); - +uint64 consume_request_queue(const int batch_size){ if (del_return_queue_plan == NULL) { SPIPlanPtr tmp = SPI_prepare("\ WITH\ @@ -214,47 +189,40 @@ uint64 consume_request_queue(CURLM *curl_mhandle, int batch_size, MemoryContext if (ret_code != SPI_OK_DELETE_RETURNING) ereport(ERROR, errmsg("Error getting http request queue: %s", SPI_result_code_string(ret_code))); - uint64 affected_rows = SPI_processed; - - for (size_t j = 0; j < affected_rows; j++) { - bool tupIsNull = false; - - int64 id = DatumGetInt64(SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 1, &tupIsNull)); - EREPORT_NULL_ATTR(tupIsNull, id); - - int32 timeout_milliseconds = DatumGetInt32(SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 4, &tupIsNull)); - EREPORT_NULL_ATTR(tupIsNull, timeout_milliseconds); + return SPI_processed; +} - Datum method = SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 2, &tupIsNull); - EREPORT_NULL_ATTR(tupIsNull, method); +// This has an implicit dependency on the execution of delete_return_request_queue, +// unfortunately we're not able to make this dependency explicit +// due to the design of SPI (which uses global variables) +RequestQueueRow get_request_queue_row(HeapTuple spi_tupval, TupleDesc spi_tupdesc){ + bool tupIsNull = false; - Datum url = SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 3, &tupIsNull); - EREPORT_NULL_ATTR(tupIsNull, url); + int64 id = DatumGetInt64(SPI_getbinval(spi_tupval, spi_tupdesc, 1, &tupIsNull)); + EREPORT_NULL_ATTR(tupIsNull, id); - NullableDatum headersBin = { - .value = SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 5, &tupIsNull), - .isnull = tupIsNull - }; + Datum method = SPI_getbinval(spi_tupval, spi_tupdesc, 2, &tupIsNull); + EREPORT_NULL_ATTR(tupIsNull, method); - NullableDatum bodyBin = { - .value = SPI_getbinval(SPI_tuptable->vals[j], SPI_tuptable->tupdesc, 6, &tupIsNull), - .isnull = tupIsNull - }; + Datum url = SPI_getbinval(spi_tupval, spi_tupdesc, 3, &tupIsNull); + EREPORT_NULL_ATTR(tupIsNull, url); - init_curl_handle(curl_mhandle, curl_memctx, id, url, bodyBin, headersBin, method, timeout_milliseconds); - } + int32 timeout_milliseconds = DatumGetInt32(SPI_getbinval(spi_tupval, spi_tupdesc, 4, &tupIsNull)); + EREPORT_NULL_ATTR(tupIsNull, timeout_milliseconds); - SPI_finish(); + NullableDatum headersBin = { + .value = SPI_getbinval(spi_tupval, spi_tupdesc, 5, &tupIsNull), + .isnull = tupIsNull + }; - return affected_rows; -} + NullableDatum bodyBin = { + .value = SPI_getbinval(spi_tupval, spi_tupdesc, 6, &tupIsNull), + .isnull = tupIsNull + }; -static void pfree_curl_data(CurlData *cdata){ - if(cdata->body){ - destroyStringInfo(cdata->body); - } - if(cdata->request_headers) //curl_slist_free_all already handles the NULL case, but be explicit about it - curl_slist_free_all(cdata->request_headers); + return (RequestQueueRow){ + id, method, url, timeout_milliseconds, headersBin, bodyBin + }; } static Jsonb *jsonb_headers_from_curl_handle(CURL *ez_handle){ @@ -276,25 +244,25 @@ static Jsonb *jsonb_headers_from_curl_handle(CURL *ez_handle){ return jsonb_headers; } -static void insert_response(CURL *ez_handle, CurlData *cdata, CURLcode curl_return_code){ +void insert_response(CurlHandle *handle, CURLcode curl_return_code){ enum { nparams = 7 }; // using an enum because const size_t nparams doesn't compile Datum vals[nparams]; char nulls[nparams]; MemSet(nulls, 'n', nparams); - vals[0] = Int64GetDatum(cdata->id); + vals[0] = Int64GetDatum(handle->id); nulls[0] = ' '; if (curl_return_code == CURLE_OK) { - Jsonb *jsonb_headers = jsonb_headers_from_curl_handle(ez_handle); + Jsonb *jsonb_headers = jsonb_headers_from_curl_handle(handle->ez_handle); long res_http_status_code = 0; - EREPORT_CURL_GETINFO(ez_handle, CURLINFO_RESPONSE_CODE, &res_http_status_code); + EREPORT_CURL_GETINFO(handle->ez_handle, CURLINFO_RESPONSE_CODE, &res_http_status_code); vals[1] = Int32GetDatum(res_http_status_code); nulls[1] = ' '; - if (cdata->body && cdata->body->data[0] != '\0'){ - vals[2] = CStringGetTextDatum(cdata->body->data); + if (handle->body && handle->body->data[0] != '\0'){ + vals[2] = CStringGetTextDatum(handle->body->data); nulls[2] = ' '; } @@ -302,7 +270,7 @@ static void insert_response(CURL *ez_handle, CurlData *cdata, CURLcode curl_retu nulls[3] = ' '; struct curl_header *hdr; - if (curl_easy_header(ez_handle, "content-type", 0, CURLH_HEADER, -1, &hdr) == CURLHE_OK){ + if (curl_easy_header(handle->ez_handle, "content-type", 0, CURLH_HEADER, -1, &hdr) == CURLHE_OK){ vals[4] = CStringGetTextDatum(hdr->value); nulls[4] = ' '; } @@ -314,7 +282,7 @@ static void insert_response(CURL *ez_handle, CurlData *cdata, CURLcode curl_retu char *error_msg = NULL; if (timed_out){ - error_msg = detailed_timeout_strerror(ez_handle, cdata->timeout_milliseconds).msg; + error_msg = detailed_timeout_strerror(handle->ez_handle, handle->timeout_milliseconds).msg; } else { error_msg = (char *) curl_easy_strerror(curl_return_code); } @@ -352,36 +320,15 @@ static void insert_response(CURL *ez_handle, CurlData *cdata, CURLcode curl_retu } } -// Switch back to the curl memory context, which has the curl handles stored -void insert_curl_responses(WorkerState *wstate, MemoryContext curl_memctx){ - MemoryContext old_ctx = MemoryContextSwitchTo(curl_memctx); - int msgs_left=0; - CURLMsg *msg = NULL; - CURLM *curl_mhandle = wstate->curl_mhandle; - - while ((msg = curl_multi_info_read(curl_mhandle, &msgs_left))) { - if (msg->msg == CURLMSG_DONE) { - CURLcode return_code = msg->data.result; - CURL *ez_handle= msg->easy_handle; - CurlData *cdata = NULL; - EREPORT_CURL_GETINFO(ez_handle, CURLINFO_PRIVATE, &cdata); - - SPI_connect(); - insert_response(ez_handle, cdata, return_code); - SPI_finish(); +void pfree_handle(CurlHandle *handle){ + pfree(handle->url); + pfree(handle->method); + if(handle->req_body) + pfree(handle->req_body); - pfree_curl_data(cdata); + if(handle->body) + destroyStringInfo(handle->body); - int res = curl_multi_remove_handle(curl_mhandle, ez_handle); - if(res != CURLM_OK) - ereport(ERROR, errmsg("curl_multi_remove_handle: %s", curl_multi_strerror(res))); - - curl_easy_cleanup(ez_handle); - } else { - ereport(ERROR, errmsg("curl_multi_info_read(), CURLMsg=%d\n", msg->msg)); - } - } - - MemoryContextSwitchTo(old_ctx); + if(handle->request_headers) //curl_slist_free_all already handles the NULL case, but be explicit about it + curl_slist_free_all(handle->request_headers); } - diff --git a/src/core.h b/src/core.h index 7c3b85b..a4ee370 100644 --- a/src/core.h +++ b/src/core.h @@ -7,22 +7,51 @@ typedef enum { WS_EXITED, } WorkerStatus; +// the state of the background worker typedef struct { pg_atomic_uint32 got_restart; pg_atomic_uint32 should_wake; pg_atomic_uint32 status; Latch* shared_latch; - ConditionVariable cv; + ConditionVariable cv; // required to publish the state of the worker to other backends int epfd; CURLM *curl_mhandle; } WorkerState; +// A row coming from the http_request_queue +typedef struct { + int64 id; + Datum method; + Datum url; + int32 timeout_milliseconds; + NullableDatum headersBin; + NullableDatum bodyBin; +} RequestQueueRow; + +// The curl easy handle plus additional data, this acts for both the request and response cycle +typedef struct { + int64 id; + StringInfo body; + struct curl_slist* request_headers; + int32 timeout_milliseconds; + char *url; + char *req_body; + char *method; + CURL *ez_handle; +} CurlHandle; + uint64 delete_expired_responses(char *ttl, int batch_size); -uint64 consume_request_queue(CURLM *curl_mhandle, int batch_size, MemoryContext curl_memctx); +uint64 consume_request_queue(const int batch_size); -void insert_curl_responses(WorkerState *wstate, MemoryContext curl_memctx); +RequestQueueRow get_request_queue_row(HeapTuple spi_tupval, TupleDesc spi_tupdesc); void set_curl_mhandle(WorkerState *wstate); +void insert_response(CurlHandle *handle, CURLcode curl_return_code); + +void init_curl_handle(CurlHandle *handle, RequestQueueRow row); + +void pfree_handle(CurlHandle *handle); + #endif diff --git a/src/event.c b/src/event.c index cf7dcf0..2663d15 100644 --- a/src/event.c +++ b/src/event.c @@ -73,12 +73,12 @@ int multi_socket_cb(__attribute__ ((unused)) CURL *easy, curl_socket_t sockfd, i int epoll_op; if(!socketp){ epoll_op = EPOLL_CTL_ADD; - bool *socket_exists = palloc(sizeof(bool)); - curl_multi_assign(wstate->curl_mhandle, sockfd, socket_exists); + bool socket_exists = true; + curl_multi_assign(wstate->curl_mhandle, sockfd, &socket_exists); } else if (what == CURL_POLL_REMOVE){ epoll_op = EPOLL_CTL_DEL; - pfree(socketp); - curl_multi_assign(wstate->curl_mhandle, sockfd, NULL); + bool socket_exists = false; + curl_multi_assign(wstate->curl_mhandle, sockfd, &socket_exists); } else { epoll_op = EPOLL_CTL_MOD; } diff --git a/src/worker.c b/src/worker.c index d4b92ed..63d1194 100644 --- a/src/worker.c +++ b/src/worker.c @@ -36,7 +36,6 @@ static char* guc_ttl; static int guc_batch_size; static char* guc_database_name; static char* guc_username; -static MemoryContext CurlMemContext = NULL; #if PG15_GTE static shmem_request_hook_type prev_shmem_request_hook = NULL; @@ -290,17 +289,31 @@ void pg_net_worker(__attribute__ ((unused)) Datum main_arg) { break; } + SPI_connect(); + expired_responses = delete_expired_responses(guc_ttl, guc_batch_size); elog(DEBUG1, "Deleted "UINT64_FORMAT" expired rows", expired_responses); - requests_consumed = consume_request_queue(worker_state->curl_mhandle, guc_batch_size, CurlMemContext); + requests_consumed = consume_request_queue(guc_batch_size); elog(DEBUG1, "Consumed "UINT64_FORMAT" request rows", requests_consumed); if(requests_consumed > 0){ + CurlHandle *handles = palloc(mul_size(sizeof(CurlHandle), requests_consumed)); + + // initialize curl handles + for (size_t j = 0; j < requests_consumed; j++) { + init_curl_handle(&handles[j], get_request_queue_row(SPI_tuptable->vals[j], SPI_tuptable->tupdesc)); + + EREPORT_MULTI( + curl_multi_add_handle(worker_state->curl_mhandle, handles[j].ez_handle) + ); + } + + // start curl event loop int running_handles = 0; - int maxevents = guc_batch_size + 1; // 1 extra for the timer + int maxevents = requests_consumed + 1; // 1 extra for the timer event events[maxevents]; do { @@ -334,22 +347,43 @@ void pg_net_worker(__attribute__ ((unused)) Datum main_arg) { &running_handles) ); } - } - insert_curl_responses(worker_state, CurlMemContext); + // insert finished responses + CURLMsg *msg = NULL; int msgs_left=0; + while ((msg = curl_multi_info_read(worker_state->curl_mhandle, &msgs_left))) { + if (msg->msg == CURLMSG_DONE) { + CurlHandle *handle = NULL; EREPORT_CURL_GETINFO(msg->easy_handle, CURLINFO_PRIVATE, &handle); + insert_response(handle, msg->data.result); + } else { + ereport(ERROR, errmsg("curl_multi_info_read(), CURLMsg=%d\n", msg->msg)); + } + } elog(DEBUG1, "Pending curl running_handles: %d", running_handles); } while (running_handles > 0); // run while there are curl handles, some won't finish in a single iteration since they could be slow and waiting for a timeout + + // cleanup + for(uint64 i = 0; i < requests_consumed; i++){ + EREPORT_MULTI( + curl_multi_remove_handle(worker_state->curl_mhandle, handles[i].ez_handle) + ); + + curl_easy_cleanup(handles[i].ez_handle); + + pfree_handle(&handles[i]); + } + + pfree(handles); } + SPI_finish(); + unlock_extension(ext_table_oids); PopActiveSnapshot(); CommitTransactionCommand(); - MemoryContextReset(CurlMemContext); - // slow down queue processing to avoid using too much CPU wait_while_processing_interrupts(WORKER_WAIT_ONE_SECOND, &worker_should_restart); @@ -430,12 +464,6 @@ void _PG_init(void) { prev_shmem_startup_hook = shmem_startup_hook; shmem_startup_hook = net_shmem_startup; - CurlMemContext = AllocSetContextCreate(TopMemoryContext, - "pg_net curl context", - ALLOCSET_DEFAULT_MINSIZE, - ALLOCSET_DEFAULT_INITSIZE, - ALLOCSET_DEFAULT_MAXSIZE); - DefineCustomStringVariable("pg_net.ttl", "time to live for request/response rows", "should be a valid interval type",