diff --git a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl index 6a0cacd81131..88880b9ff992 100644 --- a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl +++ b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl @@ -103,7 +103,7 @@ -type value() :: string(). -type header() :: {Field :: field(), Value :: value()}. -type headers() :: [header()]. --type body() :: string() | binary(). +-type body() :: iodata(). -type ssl_options() :: [ssl:tls_client_option()]. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index e0c85ec55372..94452827e008 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -10,15 +10,20 @@ %% API exports -export([ - get/2, get/3, + get/2, get/3, get/4, + put/4, put/5, post/4, refresh_credentials/0, request/5, request/6, request/7, set_credentials/2, has_credentials/0, + parse_uri/1, set_region/1, ensure_imdsv2_token_valid/0, - api_get_request/2 + api_get_request/2, + status_text/1, + open_connection/1, open_connection/2, + close_connection/1 ]). %% gen-server exports @@ -40,23 +45,33 @@ -include("rabbitmq_aws.hrl"). -include_lib("kernel/include/logger.hrl"). +%% Types for new concurrent API +-type connection_handle() :: {pid(), credential_context()}. +-type credential_context() :: #{ + access_key => access_key(), + secret_access_key => secret_access_key(), + security_token => security_token(), + region => region(), + service => string() +}. + %%==================================================================== %% exported wrapper functions %%==================================================================== -spec get( - Service :: string(), + ServiceOrHandle :: string() | connection_handle(), Path :: path() ) -> result(). %% @doc Perform a HTTP GET request to the AWS API for the specified service. The %% response will automatically be decoded if it is either in JSON, or XML %% format. %% @end -get(Service, Path) -> - get(Service, Path, []). +get(ServiceOrHandle, Path) -> + get(ServiceOrHandle, Path, []). -spec get( - Service :: string(), + ServiceOrHandle :: string() | connection_handle(), Path :: path(), Headers :: headers() ) -> result(). @@ -64,11 +79,30 @@ get(Service, Path) -> %% response will automatically be decoded if it is either in JSON or XML %% format. %% @end -get(Service, Path, Headers) -> - request(Service, get, Path, "", Headers). +get(ServiceOrHandle, Path, Headers) -> + get(ServiceOrHandle, Path, Headers, []). + +get(Service, Path, Headers, Options) -> + request(Service, get, Path, <<>>, Headers, Options). -spec post( - Service :: string(), + ServiceOrHandle :: string() | connection_handle(), + Path :: path(), + Body :: body(), + Headers :: headers() +) -> result(). +%% @doc Perform a HTTP Post request to the AWS API for the specified service. The +%% response will automatically be decoded if it is either in JSON or XML +%% format. +%% @end +post(ServiceOrHandle, Path, Body, Headers) -> + post(ServiceOrHandle, Path, Body, Headers, []). + +post(Service, Path, Body, Headers, Options) -> + request(Service, post, Path, Body, Headers, Options). + +-spec put( + ServiceOrHandle :: string() | connection_handle(), Path :: path(), Body :: body(), Headers :: headers() @@ -77,8 +111,11 @@ get(Service, Path, Headers) -> %% response will automatically be decoded if it is either in JSON or XML %% format. %% @end -post(Service, Path, Body, Headers) -> - request(Service, post, Path, Body, Headers). +put(ServiceOrHandle, Path, Body, Headers) -> + put(ServiceOrHandle, Path, Body, Headers, []). + +put(Service, Path, Body, Headers, Options) -> + request(Service, put, Path, Body, Headers, Options). -spec refresh_credentials() -> ok | error. %% @doc Manually refresh the credentials from the environment, filesystem or EC2 Instance Metadata Service. @@ -86,6 +123,46 @@ post(Service, Path, Body, Headers) -> refresh_credentials() -> gen_server:call(rabbitmq_aws, refresh_credentials). +%%==================================================================== +%% New Concurrent API Functions +%%==================================================================== + +%% Open a connection and return handle for direct use +-spec open_connection(Service :: string()) -> {ok, connection_handle()} | {error, term()}. +open_connection(Service) -> + open_connection(Service, []). + +-spec open_connection(Service :: string(), Options :: list()) -> + {ok, connection_handle()} | {error, term()}. +open_connection(Service, Options) -> + gen_server:call(?MODULE, {open_direct_connection, Service, Options}). + +%% Close a direct connection +-spec close_connection(Handle :: connection_handle()) -> ok. +close_connection({GunPid, _CredContext}) -> + gun:close(GunPid). + +-spec direct_request( + Handle :: connection_handle(), + Method :: method(), + Path :: path(), + Body :: body(), + Headers :: headers(), + Options :: list() +) -> result(). +direct_request({GunPid, CredContext}, Method, Path, Body, Headers, Options) -> + #{service := Service, region := Region} = CredContext, + % Build URI for signing + Host = endpoint_host(Region, Service), + URI = create_uri(Host, Path), + BodyHash = proplists:get_value(payload_hash, Options), + % Sign headers directly (no gen_server call) + SignedHeaders = sign_headers_with_context( + CredContext, Method, URI, Headers, Body, BodyHash + ), + % Make Gun request directly + direct_gun_request(GunPid, Method, Path, SignedHeaders, Body, Options). + -spec refresh_credentials(state()) -> ok | error. %% @doc Manually refresh the credentials from the environment, filesystem or EC2 Instance Metadata Service. %% @end @@ -107,7 +184,7 @@ refresh_credentials(State) -> %% format. %% @end request(Service, Method, Path, Body, Headers) -> - gen_server:call(rabbitmq_aws, {request, Service, Method, Headers, Path, Body, [], undefined}). + request(Service, Method, Path, Body, Headers, []). -spec request( Service :: string(), @@ -122,12 +199,10 @@ request(Service, Method, Path, Body, Headers) -> %% format. %% @end request(Service, Method, Path, Body, Headers, HTTPOptions) -> - gen_server:call( - rabbitmq_aws, {request, Service, Method, Headers, Path, Body, HTTPOptions, undefined} - ). + request(Service, Method, Path, Body, Headers, HTTPOptions, undefined). -spec request( - Service :: string(), + ServiceOrHandle :: string() | connection_handle(), Method :: method(), Path :: path(), Body :: body(), @@ -140,6 +215,10 @@ request(Service, Method, Path, Body, Headers, HTTPOptions) -> %% of services such as DynamoDB. The response will automatically be decoded %% if it is either in JSON or XML format. %% @end +request({GunPid, _CredContext} = Handle, Method, Path, Body, Headers, HTTPOptions, _) when + is_pid(GunPid) +-> + direct_request(Handle, Method, Path, Body, Headers, HTTPOptions); request(Service, Method, Path, Body, Headers, HTTPOptions, Endpoint) -> gen_server:call( rabbitmq_aws, {request, Service, Method, Headers, Path, Body, HTTPOptions, Endpoint} @@ -186,9 +265,10 @@ start_link() -> -spec init(list()) -> {ok, state()}. init([]) -> + {ok, _} = application:ensure_all_started(gun), {ok, #state{}}. -terminate(_, _) -> +terminate(_, _State) -> ok. code_change(_, _, State) -> @@ -211,6 +291,18 @@ handle_msg({request, Service, Method, Headers, Path, Body, Options, Host}, State State, Service, Method, Headers, Path, Body, Options, Host ), {reply, Response, NewState}; +handle_msg({open_direct_connection, Service, Options}, State) -> + case ensure_credentials_valid_internal(State) of + {ok, ValidState} -> + case create_direct_connection(ValidState, Service, Options) of + {ok, Handle} -> + {reply, {ok, Handle}, ValidState}; + {error, Reason} -> + {reply, {error, Reason}, ValidState} + end; + {error, Reason} -> + {reply, {error, Reason}, State} + end; handle_msg(get_state, State) -> {reply, {ok, State}, State}; handle_msg(refresh_credentials, State) -> @@ -282,6 +374,8 @@ endpoint_tld(_Other) -> %% @end format_response({ok, {{_Version, 200, _Message}, Headers, Body}}) -> {ok, {Headers, maybe_decode_body(get_content_type(Headers), Body)}}; +format_response({ok, {{_Version, 206, _Message}, Headers, Body}}) -> + {ok, {Headers, maybe_decode_body(get_content_type(Headers), Body)}}; format_response({ok, {{_Version, StatusCode, Message}, Headers, Body}}) when StatusCode >= 400 -> {error, Message, {Headers, maybe_decode_body(get_content_type(Headers), Body)}}; format_response({error, Reason}) -> @@ -293,9 +387,9 @@ format_response({error, Reason}) -> %% @end get_content_type(Headers) -> Value = - case proplists:get_value("content-type", Headers, undefined) of + case proplists:get_value(<<"content-type">>, Headers, undefined) of undefined -> - proplists:get_value("Content-Type", Headers, "text/xml"); + proplists:get_value(<<"Content-Type">>, Headers, "text/xml"); Other -> Other end, @@ -368,6 +462,8 @@ local_time() -> list() | body(). %% @doc Attempt to decode the response body by its MIME %% @end +maybe_decode_body(_, <<>>) -> + <<>>; maybe_decode_body({"application", "x-amz-json-1.0"}, Body) -> rabbitmq_aws_json:decode(Body); maybe_decode_body({"application", "json"}, Body) -> @@ -380,6 +476,8 @@ maybe_decode_body(_ContentType, Body) -> -spec parse_content_type(ContentType :: string()) -> {Type :: string(), Subtype :: string()}. %% @doc parse a content type string returning a tuple of type/subtype %% @end +parse_content_type(ContentType) when is_binary(ContentType) -> + parse_content_type(binary_to_list(ContentType)); parse_content_type(ContentType) -> Parts = string:tokens(ContentType, ";"), [Type, Subtype] = string:tokens(lists:nth(1, Parts), "/"), @@ -480,15 +578,13 @@ perform_request_creds_expired(true, State, _, _, _, _, _, _, _) -> perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, Host) -> URI = endpoint(State, Host, Service, Path), SignedHeaders = sign_headers(State, Service, Method, URI, Headers, Body), - ContentType = proplists:get_value("content-type", SignedHeaders, undefined), - perform_request_with_creds(State, Method, URI, SignedHeaders, ContentType, Body, Options). + perform_request_with_creds(State, Method, URI, SignedHeaders, Body, Options). -spec perform_request_with_creds( State :: state(), Method :: method(), URI :: string(), Headers :: headers(), - ContentType :: string() | undefined, Body :: body(), Options :: http_options() ) -> @@ -496,14 +592,12 @@ perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, %% @doc Once it is validated that there are credentials to try and that they have not %% expired, perform the request and return the response. %% @end -perform_request_with_creds(State, Method, URI, Headers, undefined, "", Options0) -> - Options1 = ensure_timeout(Options0), - Response = httpc:request(Method, {URI, Headers}, Options1, []), - {format_response(Response), State}; -perform_request_with_creds(State, Method, URI, Headers, ContentType, Body, Options0) -> - Options1 = ensure_timeout(Options0), - Response = httpc:request(Method, {URI, Headers, ContentType, Body}, Options1, []), - {format_response(Response), State}. +perform_request_with_creds(State, Method, URI, Headers, "", Options0) -> + Response = gun_request(Method, URI, Headers, <<>>, Options0), + {Response, State}; +perform_request_with_creds(State, Method, URI, Headers, Body, Options0) -> + Response = gun_request(Method, URI, Headers, Body, Options0), + {Response, State}. -spec perform_request_creds_error(State :: state()) -> {result_error(), NewState :: state()}. @@ -513,22 +607,6 @@ perform_request_with_creds(State, Method, URI, Headers, ContentType, Body, Optio perform_request_creds_error(State) -> {{error, {credentials, State#state.error}}, State}. -%% @doc Ensure that the timeout option is set and greater than 0 and less -%% than about 1/2 of the default gen_server:call timeout. This gives -%% enough time for a long connect and request phase to succeed. -%% @end --spec ensure_timeout(Options :: http_options()) -> http_options(). -ensure_timeout(Options) -> - case proplists:get_value(timeout, Options) of - undefined -> - Options ++ [{timeout, ?DEFAULT_HTTP_TIMEOUT}]; - Value when is_integer(Value) andalso Value >= 0 andalso Value =< ?DEFAULT_HTTP_TIMEOUT -> - Options; - _ -> - Options1 = proplists:delete(timeout, Options), - Options1 ++ [{timeout, ?DEFAULT_HTTP_TIMEOUT}] - end. - -spec sign_headers( State :: state(), Service :: string(), @@ -648,3 +726,207 @@ api_get_request_with_retries(Service, Path, Retries, WaitTimeBetweenRetries) -> timer:sleep(WaitTimeBetweenRetries), api_get_request_with_retries(Service, Path, Retries - 1, WaitTimeBetweenRetries) end. + +%% Gun HTTP client functions +gun_request(Method, URI, Headers, Body, Options) -> + {Host, Port, Path} = parse_uri(URI), + GunPid = create_gun_connection(Host, Port, Options), + Reply = direct_gun_request(GunPid, Method, Path, Headers, Body, Options), + gun:close(GunPid), + Reply. + +do_gun_request(ConnPid, get, Path, Headers, _Body) -> + gun:get(ConnPid, Path, Headers); +do_gun_request(ConnPid, post, Path, Headers, Body) -> + gun:post(ConnPid, Path, Headers, Body, #{}); +do_gun_request(ConnPid, put, Path, Headers, Body) -> + gun:put(ConnPid, Path, Headers, Body, #{}); +do_gun_request(ConnPid, head, Path, Headers, _Body) -> + gun:head(ConnPid, Path, Headers, #{}); +do_gun_request(ConnPid, delete, Path, Headers, _Body) -> + gun:delete(ConnPid, Path, Headers, #{}); +do_gun_request(ConnPid, patch, Path, Headers, Body) -> + gun:patch(ConnPid, Path, Headers, Body, #{}); +do_gun_request(ConnPid, options, Path, Headers, _Body) -> + gun:options(ConnPid, Path, Headers, #{}). + +create_gun_connection(Host, Port, Options) -> + % Map HTTP version to Gun protocols, always include http as fallback + HttpVersion = proplists:get_value(version, Options, "HTTP/1.1"), + Protocols = + case HttpVersion of + "HTTP/2" -> [http2, http]; + "HTTP/2.0" -> [http2, http]; + "HTTP/1.1" -> [http]; + "HTTP/1.0" -> [http]; + % Default: try HTTP/2, fallback to HTTP/1.1 + _ -> [http2, http] + end, + ConnectTimeout = proplists:get_value(connect_timeout, Options, infinity), + Opts = #{ + transport => + if + Port == 443 -> tls; + true -> tcp + end, + protocols => Protocols, + connect_timeout => ConnectTimeout + }, + case gun:open(Host, Port, Opts) of + {ok, ConnPid} -> + case gun:await_up(ConnPid, ConnectTimeout) of + {ok, _Protocol} -> + ConnPid; + {error, Reason} -> + gun:close(ConnPid), + error({gun_connection_failed, Reason}) + end; + {error, Reason} -> + error({gun_open_failed, Reason}) + end. + +create_uri(Host, Path) when is_list(Path) -> + "https://" ++ Host ++ Path; +create_uri(Host, {Bucket, Key}) -> + "https://" ++ Bucket ++ "." ++ Host ++ "/" ++ Key. + +parse_uri(URI) -> + case string:split(URI, "://", leading) of + [Scheme, Rest] -> + case string:split(Rest, "/", leading) of + [HostPort] -> + {Host, Port} = parse_host_port(HostPort, Scheme), + {Host, Port, "/"}; + [HostPort, Path] -> + {Host, Port} = parse_host_port(HostPort, Scheme), + {Host, Port, "/" ++ Path} + end + end. + +parse_host_port(HostPort, Scheme) -> + DefaultPort = + case Scheme of + "https" -> 443; + "http" -> 80; + % Fallback to HTTPS + _ -> 443 + end, + case string:split(HostPort, ":", trailing) of + [Host] -> + {Host, DefaultPort}; + [Host, PortStr] -> + {Host, list_to_integer(PortStr)} + end. + +status_text(200) -> "OK"; +status_text(206) -> "Partial Content"; +status_text(400) -> "Bad Request"; +status_text(401) -> "Unauthorized"; +status_text(403) -> "Forbidden"; +status_text(404) -> "Not Found"; +status_text(416) -> "Range Not Satisfiable"; +status_text(500) -> "Internal Server Error"; +status_text(Code) -> integer_to_list(Code). + +%%==================================================================== +%% New Concurrent API Helper Functions +%%==================================================================== + +%% Create a direct connection handle +-spec create_direct_connection(State :: state(), Service :: string(), Options :: list()) -> + {ok, connection_handle()} | {error, term()}. +create_direct_connection(State, Service, Options) -> + Region = State#state.region, + Host = endpoint_host(Region, Service), + Port = 443, + GunPid = create_gun_connection(Host, Port, Options), + CredContext = #{ + access_key => State#state.access_key, + secret_access_key => State#state.secret_access_key, + security_token => State#state.security_token, + region => Region, + service => Service + }, + {ok, {GunPid, CredContext}}. + +%% Sign headers using credential context (no gen_server state needed) +-spec sign_headers_with_context( + CredContext :: credential_context(), + Method :: method(), + URI :: string(), + Headers :: headers(), + Body :: body(), + BodyHash :: iodata() +) -> headers(). +sign_headers_with_context(CredContext, Method, URI, Headers, Body, BodyHash) -> + #{ + access_key := AccessKey, + secret_access_key := SecretKey, + security_token := SecurityToken, + region := Region, + service := Service + } = CredContext, + rabbitmq_aws_sign:headers( + #request{ + access_key = AccessKey, + secret_access_key = SecretKey, + security_token = SecurityToken, + region = Region, + service = Service, + method = Method, + uri = URI, + headers = Headers, + body = Body + }, + BodyHash + ). + +%% Direct Gun request (extracted from existing gun_request function) +-spec direct_gun_request( + GunPid :: pid(), + Method :: method(), + Path :: path(), + Headers :: headers(), + Body :: body(), + Options :: list() +) -> result(). +direct_gun_request(GunPid, Method, {_, Path}, Headers, Body, Options) -> + direct_gun_request(GunPid, Method, [$/ | Path], Headers, Body, Options); +direct_gun_request(GunPid, Method, Path, Headers, Body, Options) -> + HeadersBin = lists:map( + fun({Key, Value}) -> + {list_to_binary(Key), list_to_binary(Value)} + end, + Headers + ), + Timeout = proplists:get_value(timeout, Options, ?DEFAULT_HTTP_TIMEOUT), + Response = + try + StreamRef = do_gun_request(GunPid, Method, Path, HeadersBin, Body), + case gun:await(GunPid, StreamRef, Timeout) of + {response, fin, Status, RespHeaders} -> + {ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}}; + {response, nofin, Status, RespHeaders} -> + {ok, RespBody} = gun:await_body(GunPid, StreamRef, Timeout), + {ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}}; + {error, Reason} -> + {error, Reason} + end + catch + _:Error -> + {error, Error} + end, + format_response(Response). + +%% Internal credential validation (extracted from existing logic) +-spec ensure_credentials_valid_internal(State :: state()) -> {ok, state()} | {error, term()}. +ensure_credentials_valid_internal(State) -> + case has_credentials(State) of + true -> + case expired_credentials(State#state.expiration) of + false -> {ok, State}; + true -> load_credentials(State) + end; + false -> + load_credentials(State) + end. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl index 3d2ae89fe918..4ba821249a99 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl @@ -629,9 +629,14 @@ maybe_get_role_from_instance_metadata() -> %% @doc Parse the response from the Availability Zone query to the %% Instance Metadata service, returning the Region if successful. %% end. -parse_az_response({error, _}) -> {error, undefined}; -parse_az_response({ok, {{_, 200, _}, _, Body}}) -> {ok, region_from_availability_zone(Body)}; -parse_az_response({ok, {{_, _, _}, _, _}}) -> {error, undefined}. +parse_az_response({error, _}) -> + {error, undefined}; +parse_az_response({ok, {{_, 200, _}, _, Body}}) when is_binary(Body) -> + {ok, region_from_availability_zone(binary_to_list(Body))}; +parse_az_response({ok, {{_, 200, _}, _, Body}}) -> + {ok, region_from_availability_zone(Body)}; +parse_az_response({ok, {{_, _, _}, _, _}}) -> + {error, undefined}. -spec parse_body_response(httpc_result()) -> {ok, Value :: string()} | {error, Reason :: atom()}. @@ -640,8 +645,9 @@ parse_az_response({ok, {{_, _, _}, _, _}}) -> {error, undefined}. %% end. parse_body_response({error, _}) -> {error, undefined}; -parse_body_response({ok, {{_, 200, _}, _, Body}}) -> - {ok, Body}; +parse_body_response({ok, {{_, 200, _}, _, Body}}) when is_binary(Body) -> + {ok, binary_to_list(Body)}; +parse_body_response({ok, {{_, 200, _}, _, Body}}) when is_list(Body) -> {ok, Body}; parse_body_response({ok, {{_, 401, _}, _, _}}) -> ?LOG_ERROR( get_instruction_on_instance_metadata_error( @@ -678,12 +684,47 @@ parse_credentials_response({ok, {{_, 200, _}, _, Body}}) -> %% @end perform_http_get_instance_metadata(URL) -> ?LOG_DEBUG("Querying instance metadata service: ~tp", [URL]), - httpc:request( - get, - {URL, instance_metadata_request_headers()}, - [{timeout, ?DEFAULT_HTTP_TIMEOUT}], - [] - ). + % Parse metadata service URL + {Host, Port, Path} = rabbitmq_aws:parse_uri(URL), + % Simple Gun connection for metadata service + + % HTTP only, no TLS + Opts = #{transport => tcp, protocols => [http]}, + case gun:open(Host, Port, Opts) of + {ok, ConnPid} -> + case gun:await_up(ConnPid, 5000) of + {ok, _Protocol} -> + Headers = instance_metadata_request_headers(), + StreamRef = gun:get(ConnPid, Path, Headers), + Result = + case gun:await(ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT) of + {response, fin, Status, RespHeaders} -> + {ok, { + {http_version, Status, rabbitmq_aws:status_text(Status)}, + RespHeaders, + <<>> + }}; + {response, nofin, Status, RespHeaders} -> + {ok, Body} = gun:await_body( + ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT + ), + {ok, { + {http_version, Status, rabbitmq_aws:status_text(Status)}, + RespHeaders, + Body + }}; + {error, Reason} -> + {error, Reason} + end, + gun:close(ConnPid), + Result; + {error, Reason} -> + gun:close(ConnPid), + {error, Reason} + end; + {error, Reason} -> + {error, Reason} + end. -spec get_instruction_on_instance_metadata_error(string()) -> string(). %% @doc Return error message on failures related to EC2 Instance Metadata Service with a reference to AWS document. @@ -742,29 +783,77 @@ region_from_availability_zone(Value) -> load_imdsv2_token() -> TokenUrl = imdsv2_token_url(), ?LOG_INFO("Attempting to obtain EC2 IMDSv2 token from ~tp ...", [TokenUrl]), - case - httpc:request( - put, - {TokenUrl, [{?METADATA_TOKEN_TTL_HEADER, integer_to_list(?METADATA_TOKEN_TTL_SECONDS)}]}, - [{timeout, ?DEFAULT_HTTP_TIMEOUT}], - [] - ) - of - {ok, {{_, 200, _}, _, Value}} -> - ?LOG_DEBUG("Successfully obtained EC2 IMDSv2 token."), - Value; - {error, {{_, 400, _}, _, _}} -> - ?LOG_WARNING( - "Failed to obtain EC2 IMDSv2 token: Missing or Invalid Parameters – The PUT request is not valid." - ), - undefined; - Other -> + % Parse metadata service URL + {Host, Port, Path} = rabbitmq_aws:parse_uri(TokenUrl), + % Simple Gun connection for metadata service + + % HTTP only, no TLS + Opts = #{transport => tcp, protocols => [http]}, + case gun:open(Host, Port, Opts) of + {ok, ConnPid} -> + case gun:await_up(ConnPid, 5000) of + {ok, _Protocol} -> + % PUT request with IMDSv2 token TTL header + Headers = [ + {?METADATA_TOKEN_TTL_HEADER, integer_to_list(?METADATA_TOKEN_TTL_SECONDS)} + ], + StreamRef = gun:put(ConnPid, Path, Headers, <<>>), + Result = + case gun:await(ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT) of + {response, fin, 200, _RespHeaders} -> + ?LOG_DEBUG("Successfully obtained EC2 IMDSv2 token."), + % Empty body for fin response + <<>>; + {response, nofin, 200, _RespHeaders} -> + {ok, Body} = gun:await_body( + ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT + ), + ?LOG_DEBUG("Successfully obtained EC2 IMDSv2 token."), + binary_to_list(Body); + {response, _, 400, _RespHeaders} -> + ?LOG_WARNING( + "Failed to obtain EC2 IMDSv2 token: Missing or Invalid Parameters – The PUT request is not valid." + ), + undefined; + {error, Reason} -> + ?LOG_WARNING( + get_instruction_on_instance_metadata_error( + "Failed to obtain EC2 IMDSv2 token: ~tp. " + "Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2." + ), + [Reason] + ), + undefined; + Other -> + ?LOG_WARNING( + get_instruction_on_instance_metadata_error( + "Failed to obtain EC2 IMDSv2 token: ~tp. " + "Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2." + ), + [Other] + ), + undefined + end, + gun:close(ConnPid), + Result; + {error, Reason} -> + gun:close(ConnPid), + ?LOG_WARNING( + get_instruction_on_instance_metadata_error( + "Failed to connect for EC2 IMDSv2 token: ~tp. " + "Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2." + ), + [Reason] + ), + undefined + end; + {error, Reason} -> ?LOG_WARNING( get_instruction_on_instance_metadata_error( - "Failed to obtain EC2 IMDSv2 token: ~tp. " + "Failed to open connection for EC2 IMDSv2 token: ~tp. " "Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2." ), - [Other] + [Reason] ), undefined end. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl index 7a95a2b44e77..c5f9b0bddd9d 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl @@ -8,7 +8,7 @@ -module(rabbitmq_aws_sign). %% API --export([headers/1, request_hash/5]). +-export([headers/1, headers/2, request_hash/5]). %% Export all for unit tests -ifdef(TEST). @@ -24,13 +24,19 @@ %% @doc Create the signed request headers %% end headers(Request) -> + headers(Request, undefined). + +headers(Request, undefined) -> + headers(Request, sha256(Request#request.body)); +headers(Request, PayloadHash) -> RequestTimestamp = local_time(), - PayloadHash = sha256(Request#request.body), URI = rabbitmq_aws_urilib:parse(Request#request.uri), {_, Host, _} = URI#uri.authority, + BodyLength = iolist_size(Request#request.body), + Headers = append_headers( RequestTimestamp, - length(Request#request.body), + BodyLength, PayloadHash, Host, Request#request.security_token, @@ -41,7 +47,7 @@ headers(Request) -> URI#uri.path, URI#uri.query, Headers, - Request#request.body + PayloadHash ), AuthValue = authorization( Request#request.access_key, @@ -202,11 +208,11 @@ query_string(QueryArgs) -> rabbitmq_aws_urilib:build_query_string(lists:keysort( Path :: path(), QArgs :: query_args(), Headers :: headers(), - Payload :: string() + PayloadHash :: string() ) -> string(). %% @doc Create the request hash value %% @end -request_hash(Method, Path, QArgs, Headers, Payload) -> +request_hash(Method, Path, QArgs, Headers, PayloadHash) -> RawPath = case string:slice(Path, 0, 1) of "/" -> Path; @@ -220,7 +226,7 @@ request_hash(Method, Path, QArgs, Headers, Payload) -> query_string(QArgs), canonical_headers(Headers), signed_headers(Headers), - sha256(Payload) + PayloadHash ], "\n" ), @@ -236,7 +242,7 @@ request_hash(Method, Path, QArgs, Headers, Payload) -> scope(AMZDate, Region, Service) -> string:join([AMZDate, Region, Service, "aws4_request"], "/"). --spec sha256(Value :: string()) -> string(). +-spec sha256(Value :: iodata()) -> string(). %% @doc Return the SHA-256 hash for the specified value. %% @end sha256(Value) -> diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl index 250fc1fc882e..98094aea87eb 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl @@ -11,6 +11,8 @@ -include_lib("xmerl/include/xmerl.hrl"). -spec parse(Value :: string() | binary()) -> list(). +parse(Value) when is_binary(Value) -> + parse(binary_to_list(Value)); parse(Value) -> {Element, _} = xmerl_scan:string(Value), parse_node(Element). diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl index cca1b4af8231..fd6c30376c37 100644 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl +++ b/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl @@ -120,10 +120,10 @@ credentials_test_() -> { foreach, fun() -> - meck:new(httpc), - meck:new(rabbitmq_aws), + meck:new(gun, []), + meck:new(rabbitmq_aws, [passthrough]), reset_environment(), - [httpc, rabbitmq_aws] + [gun, rabbitmq_aws] end, fun meck:unload/1, [ @@ -222,13 +222,26 @@ credentials_test_() -> {"from instance metadata service", fun() -> CredsBody = "{\n \"Code\" : \"Success\",\n \"LastUpdated\" : \"2016-03-31T21:51:49Z\",\n \"Type\" : \"AWS-HMAC\",\n \"AccessKeyId\" : \"ASIAIMAFAKEACCESSKEY\",\n \"SecretAccessKey\" : \"2+t64tZZVaz0yp0x1G23ZRYn+FAKEyVALUEs/4qh\",\n \"Token\" : \"FAKE//////////wEAK/TOKEN/VALUE=\",\n \"Expiration\" : \"2016-04-01T04:13:28Z\"\n}", + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:sequence(gun, get, 3, [stream_ref1, stream_ref2]), meck:sequence( - httpc, - request, - 4, + gun, + await, + 3, [ - {ok, {{protocol, 200, message}, headers, "Bob"}}, - {ok, {{protocol, 200, message}, headers, CredsBody}} + {response, nofin, 200, headers}, + {response, nofin, 200, headers} + ] + ), + meck:sequence( + gun, + await_body, + 3, + [ + {ok, <<"Bob">>}, + {ok, list_to_binary(CredsBody)} ] ), meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), @@ -239,41 +252,59 @@ credentials_test_() -> end}, {"with instance metadata service role error", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), - meck:expect(httpc, request, 4, {error, timeout}), + meck:expect(gun, open, fun(_, _, _) -> {error, timeout} end), ?assertEqual({error, undefined}, rabbitmq_aws_config:credentials()) end}, {"with instance metadata service role http error", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), - meck:expect( - httpc, - request, - 4, - {ok, {{protocol, 500, message}, headers, "Internal Server Error"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 500, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Internal Server Error">>} end), ?assertEqual({error, undefined}, rabbitmq_aws_config:credentials()) end}, {"with instance metadata service credentials error", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:sequence(gun, get, 3, [stream_ref1, stream_ref2]), meck:sequence( - httpc, - request, - 4, + gun, + await, + 3, [ - {ok, {{protocol, 200, message}, headers, "Bob"}}, + {response, nofin, 200, headers}, {error, timeout} ] ), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Bob">>} end), ?assertEqual({error, undefined}, rabbitmq_aws_config:credentials()) end}, {"with instance metadata service credentials not found", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:sequence(gun, get, 3, [stream_ref1, stream_ref2]), + meck:sequence( + gun, + await, + 3, + [ + {response, nofin, 200, headers}, + {response, nofin, 404, headers} + ] + ), meck:sequence( - httpc, - request, - 4, + gun, + await_body, + 3, [ - {ok, {{protocol, 200, message}, headers, "Bob"}}, - {ok, {{protocol, 404, message}, headers, "File Not Found"}} + {ok, <<"Bob">>}, + {ok, <<"File Not Found">>} ] ), ?assertEqual({error, undefined}, rabbitmq_aws_config:credentials()) @@ -357,10 +388,10 @@ region_test_() -> { foreach, fun() -> - meck:new(httpc), - meck:new(rabbitmq_aws), + meck:new(gun, []), + meck:new(rabbitmq_aws, [passthrough]), reset_environment(), - [httpc, rabbitmq_aws] + [gun, rabbitmq_aws] end, fun meck:unload/1, [ @@ -383,12 +414,12 @@ region_test_() -> end}, {"from instance metadata service", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), - meck:expect( - httpc, - request, - 4, - {ok, {{protocol, 200, message}, headers, "us-west-1a"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 200, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"us-west-1a">>} end), ?assertEqual({ok, "us-west-1"}, rabbitmq_aws_config:region()) end}, {"full lookup failure", fun() -> @@ -397,12 +428,12 @@ region_test_() -> end}, {"http error failure", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), - meck:expect( - httpc, - request, - 4, - {ok, {{protocol, 500, message}, headers, "Internal Server Error"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 500, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Internal Server Error">>} end), ?assertEqual({ok, ?DEFAULT_REGION}, rabbitmq_aws_config:region()) end} ] @@ -412,32 +443,41 @@ instance_id_test_() -> { foreach, fun() -> - meck:new(httpc), - meck:new(rabbitmq_aws), + meck:new(gun, []), + meck:new(rabbitmq_aws, [passthrough]), reset_environment(), - [httpc, rabbitmq_aws] + [gun, rabbitmq_aws] end, fun meck:unload/1, [ {"get instance id successfully", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), - meck:expect( - httpc, request, 4, {ok, {{protocol, 200, message}, headers, "instance-id"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 200, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"instance-id">>} end), ?assertEqual({ok, "instance-id"}, rabbitmq_aws_config:instance_id()) end}, {"getting instance id is rejected with invalid token error", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, "invalid"), - meck:expect( - httpc, request, 4, {error, {{protocol, 401, message}, headers, "Invalid token"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 401, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Invalid token">>} end), ?assertEqual({error, undefined}, rabbitmq_aws_config:instance_id()) end}, {"getting instance id is rejected with access denied error", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, "expired token"), - meck:expect( - httpc, request, 4, {error, {{protocol, 403, message}, headers, "access denied"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 403, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"access denied">>} end), ?assertEqual({error, undefined}, rabbitmq_aws_config:instance_id()) end} ] @@ -447,36 +487,34 @@ load_imdsv2_token_test_() -> { foreach, fun() -> - meck:new(httpc), - [httpc] + meck:new(gun, []), + [gun] end, fun meck:unload/1, [ {"fail to get imdsv2 token - timeout", fun() -> - meck:expect(httpc, request, 4, {error, timeout}), + meck:expect(gun, open, fun(_, _, _) -> {error, timeout} end), ?assertEqual(undefined, rabbitmq_aws_config:load_imdsv2_token()) end}, {"fail to get imdsv2 token - PUT request is not valid", fun() -> - meck:expect( - httpc, - request, - 4, - {error, { - {protocol, 400, messge}, - headers, - "Missing or Invalid Parameters – The PUT request is not valid." - }} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, put, fun(_, _, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 400, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> + {ok, <<"Missing or Invalid Parameters – The PUT request is not valid.">>} + end), ?assertEqual(undefined, rabbitmq_aws_config:load_imdsv2_token()) end}, {"successfully get imdsv2 token from instance metadata service", fun() -> IMDSv2Token = "super_secret_token_value", - meck:sequence( - httpc, - request, - 4, - [{ok, {{protocol, 200, message}, headers, IMDSv2Token}}] - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, put, fun(_, _, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 200, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, list_to_binary(IMDSv2Token)} end), ?assertEqual(IMDSv2Token, rabbitmq_aws_config:load_imdsv2_token()) end} ] @@ -486,7 +524,7 @@ maybe_imdsv2_token_headers_test_() -> { foreach, fun() -> - meck:new(rabbitmq_aws), + meck:new(rabbitmq_aws, [passthrough]), [rabbitmq_aws] end, fun meck:unload/1, @@ -516,7 +554,7 @@ reset_environment() -> "AWS_SHARED_CREDENTIALS_FILE", "bad_credentials.ini" ), - meck:expect(httpc, request, 4, {error, timeout}). + meck:expect(gun, open, fun(_, _, _) -> {error, timeout} end). setup_test_config_env_var() -> setup_test_file_with_env_var("AWS_CONFIG_FILE", "test_aws_config.ini"). diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl index 7f5eaa906e44..66c23e0f65cc 100644 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl +++ b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl @@ -46,7 +46,13 @@ init_test_() -> ]}. terminate_test() -> - ?assertEqual(ok, rabbitmq_aws:terminate(foo, bar)). + ?assertEqual( + ok, + rabbitmq_aws:terminate( + foo, + {state, undefined, undefined, undefined, undefined, "us-west-3", undefined, test_result} + ) + ). code_change_test() -> ?assertEqual({ok, {state, denial}}, rabbitmq_aws:code_change(foo, bar, {state, denial})). @@ -133,9 +139,11 @@ format_response_test_() -> {"ok", fun() -> Response = {ok, { - {"HTTP/1.1", 200, "Ok"}, [{"Content-Type", "text/xml"}], "Value" + {"HTTP/1.1", 200, "Ok"}, + [{<<"Content-Type">>, <<"text/xml">>}], + "Value" }}, - Expectation = {ok, {[{"Content-Type", "text/xml"}], [{"test", "Value"}]}}, + Expectation = {ok, {[{<<"Content-Type">>, <<"text/xml">>}], [{"test", "Value"}]}}, ?assertEqual(Expectation, rabbitmq_aws:format_response(Response)) end}, {"error", fun() -> @@ -161,8 +169,8 @@ gen_server_call_test_() -> os:putenv("AWS_DEFAULT_REGION", "us-west-3"), os:putenv("AWS_ACCESS_KEY_ID", "Sésame"), os:putenv("AWS_SECRET_ACCESS_KEY", "ouvre-toi"), - meck:new(httpc, []), - [httpc] + meck:new(gun, []), + [gun] end, fun(Mods) -> meck:unload(Mods), @@ -186,31 +194,41 @@ gen_server_call_test_() -> Body = "", Options = [], Host = undefined, + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), meck:expect( - httpc, - request, - fun( - get, - {"https://ec2.us-east-1.amazonaws.com/?Action=DescribeTags&Version=2015-10-01", - _Headers}, - _Options, - [] - ) -> - {ok, { - {"HTTP/1.0", 200, "OK"}, - [{"content-type", "application/json"}], - "{\"pass\": true}" - }} + gun, + get, + fun(_Pid, _Path, _Headers) -> nofin end + ), + %% {ok, {{"HTTP/1.0", 200, "OK"}, [{"content-type", "application/json"}], "{\"pass\": true}"}} + %% end), + meck:expect( + gun, + await, + fun(_Pid, _, _) -> + {response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]} end ), + meck:expect( + gun, + await_body, + fun(_Pid, _, _) -> {ok, <<"{\"pass\": true}">>} end + ), + + %% {ok, {{"HTTP/1.0", 200, "OK"}, [{"content-type", "application/json"}], "{\"pass\": true}"}} + %% end), Expectation = - {reply, {ok, {[{"content-type", "application/json"}], [{"pass", true}]}}, + {reply, + {ok, + {[{<<"content-type">>, <<"application/json">>}], [{"pass", true}]}}, State}, Result = rabbitmq_aws:handle_call( {request, Service, Method, Headers, Path, Body, Options, Host}, eunit, State ), ?assertEqual(Expectation, Result), - meck:validate(httpc) + meck:validate(gun) end }, { @@ -388,9 +406,9 @@ perform_request_test_() -> { foreach, fun() -> - meck:new(httpc, []), + meck:new(gun, []), meck:new(rabbitmq_aws_config, []), - [httpc, rabbitmq_aws_config] + [gun, rabbitmq_aws_config] end, fun meck:unload/1, [ @@ -411,33 +429,37 @@ perform_request_test_() -> Host = undefined, ExpectURI = "https://ec2.us-east-1.amazonaws.com/?Action=DescribeTags&Version=2015-10-01", + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + + meck:expect( + gun, + get, + fun(_Pid, "/?Action=DescribeTags&Version=2015-10-01", _Headers) -> nofin end + ), meck:expect( - httpc, - request, - fun(get, {URI, _Headers}, _Options, []) -> - case URI of - ExpectURI -> - {ok, { - {"HTTP/1.0", 200, "OK"}, - [{"content-type", "application/json"}], - "{\"pass\": true}" - }}; - _ -> - {ok, - {{"HTTP/1.0", 400, "RequestFailure", - [{"content-type", "application/json"}], - "{\"pass\": false}"}}} - end + gun, + await, + fun(_Pid, _, _) -> + {response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]} end ), + meck:expect( + gun, + await_body, + fun(_Pid, _, _) -> {ok, <<"{\"pass\": true}">>} end + ), + Expectation = { - {ok, {[{"content-type", "application/json"}], [{"pass", true}]}}, State + {ok, {[{<<"content-type">>, <<"application/json">>}], [{"pass", true}]}}, + State }, Result = rabbitmq_aws:perform_request( State, Service, Method, Headers, Path, Body, Options, Host ), ?assertEqual(Expectation, Result), - meck:validate(httpc) + meck:validate(gun) end }, { @@ -451,19 +473,11 @@ perform_request_test_() -> Body = "", Options = [], Host = undefined, - meck:expect(httpc, request, fun(get, {_URI, _Headers}, _Options, []) -> - {ok, { - {"HTTP/1.0", 400, "RequestFailure"}, - [{"content-type", "application/json"}], - "{\"pass\": false}" - }} - end), Expectation = {{error, {credentials, State#state.error}}, State}, Result = rabbitmq_aws:perform_request( State, Service, Method, Headers, Path, Body, Options, Host ), - ?assertEqual(Expectation, Result), - meck:validate(httpc) + ?assertEqual(Expectation, Result) end }, { @@ -554,9 +568,9 @@ api_get_request_test_() -> { foreach, fun() -> - meck:new(httpc, []), + meck:new(gun, []), meck:new(rabbitmq_aws_config, []), - [httpc, rabbitmq_aws_config] + [gun, rabbitmq_aws_config] end, fun meck:unload/1, [ @@ -567,23 +581,34 @@ api_get_request_test_() -> region = "us-east-1", expiration = {{3016, 4, 1}, {12, 0, 0}} }, + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), meck:expect( - httpc, - request, - 4, - {ok, { - {"HTTP/1.0", 200, "OK"}, - [{"content-type", "application/json"}], - "{\"data\": \"value\"}" - }} + gun, + get, + fun(_Pid, _Path, _Headers) -> nofin end ), + meck:expect( + gun, + await, + fun(_Pid, _, _) -> + {response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]} + end + ), + meck:expect( + gun, + await_body, + fun(_Pid, _, _) -> {ok, <<"{\"data\": \"value\"}">>} end + ), + {ok, Pid} = rabbitmq_aws:start_link(), rabbitmq_aws:set_region("us-east-1"), rabbitmq_aws:set_credentials(State), Result = rabbitmq_aws:api_get_request("AWS", "API"), ok = gen_server:stop(Pid), ?assertEqual({ok, [{"data", "value"}]}, Result), - meck:validate(httpc) + meck:validate(gun) end}, {"AWS service API request failed - credentials", fun() -> meck:expect(rabbitmq_aws_config, credentials, 0, {error, undefined}), @@ -600,14 +625,27 @@ api_get_request_test_() -> region = "us-east-1", expiration = {{3016, 4, 1}, {12, 0, 0}} }, - meck:expect(httpc, request, 4, {error, "network error"}), + meck:expect(gun, open, fun(_, _, _) -> {ok, spawn(fun() -> ok end)} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect( + gun, + get, + fun(_Pid, _Path, _Headers) -> nofin end + ), + meck:expect( + gun, + await, + fun(_Pid, _, _) -> {error, "network error"} end + ), + {ok, Pid} = rabbitmq_aws:start_link(), rabbitmq_aws:set_region("us-east-1"), rabbitmq_aws:set_credentials(State), Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 3, 1), ok = gen_server:stop(Pid), ?assertEqual({error, "AWS service is unavailable"}, Result), - meck:validate(httpc) + meck:validate(gun) end}, {"AWS service API request succeeded after a transient error", fun() -> State = #state{ @@ -616,22 +654,35 @@ api_get_request_test_() -> region = "us-east-1", expiration = {{3016, 4, 1}, {12, 0, 0}} }, + meck:expect(gun, open, fun(_, _, _) -> {ok, spawn(fun() -> ok end)} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), meck:expect( - httpc, - request, - 4, + gun, + get, + fun(_Pid, _Path, _Headers) -> nofin end + ), + + %% meck:expect(gun, get, 3, meck:seq( + %% fun(_Pid, _Path, _Headers) -> {error, "network errors"} end), + meck:expect( + gun, + await, + 3, meck:seq([ {error, "network error"}, - {ok, { - {"HTTP/1.0", 500, "OK"}, - [{"content-type", "application/json"}], - "{\"error\": \"server error\"}" - }}, - {ok, { - {"HTTP/1.0", 200, "OK"}, - [{"content-type", "application/json"}], - "{\"data\": \"value\"}" - }} + {response, nofin, 500, [{<<"content-type">>, <<"application/json">>}]}, + {response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]} + ]) + ), + + meck:expect( + gun, + await_body, + 3, + meck:seq([ + {ok, <<"{\"error\": \"server error\"}">>}, + {ok, <<"{\"data\": \"value\"}">>} ]) ), {ok, Pid} = rabbitmq_aws:start_link(), @@ -640,7 +691,7 @@ api_get_request_test_() -> Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 3, 1), ok = gen_server:stop(Pid), ?assertEqual({ok, [{"data", "value"}]}, Result), - meck:validate(httpc) + meck:validate(gun) end} ] }.