Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for stream_management to escalus_ws #265

Merged
merged 4 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/escalus_connection.erl
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,16 @@ get_stream_end(#client{rcv_pid = Pid, jid = Jid}, Timeout) ->
-spec get_sm_h(client()) -> non_neg_integer().
get_sm_h(#client{module = escalus_tcp, rcv_pid = Pid}) ->
escalus_tcp:get_sm_h(Pid);
get_sm_h(#client{module = escalus_ws, rcv_pid = Pid}) ->
escalus_ws:get_sm_h(Pid);
get_sm_h(#client{module = Mod}) ->
error({get_sm_h, {undefined_for_escalus_module, Mod}}).

-spec set_sm_h(client(), non_neg_integer()) -> {ok, non_neg_integer()}.
set_sm_h(#client{module = escalus_tcp, rcv_pid = Pid}, H) ->
escalus_tcp:set_sm_h(Pid, H);
set_sm_h(#client{module = escalus_ws, rcv_pid = Pid}, H) ->
escalus_ws:set_sm_h(Pid, H);
set_sm_h(#client{module = Mod}, _) ->
error({set_sm_h, {undefined_for_escalus_module, Mod}}).

Expand Down
8 changes: 4 additions & 4 deletions src/escalus_tcp.erl
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ separate_ack_requests({true, H0, inactive}, Stanzas) ->
Enabled = [ S || S <- Stanzas, escalus_pred:is_sm_enabled(S)],
Resumed = [ S || S <- Stanzas, escalus_pred:is_sm_resumed(S)],

case {length(Enabled),length(Resumed)} of
case {length(Enabled), length(Resumed)} of
%% Enabled SM: set the H param to 0 and activate counter.
{1,0} -> {{true, 0, active}, [], Stanzas};

Expand All @@ -450,12 +450,12 @@ separate_ack_requests({true, H0, active}, Stanzas) ->

make_ack(H) -> {escalus_stanza:sm_ack(H), H}.

reply_to_ack_requests({false,H,A}, _, _) -> {false, H, A};
reply_to_ack_requests({true,H,inactive}, _, _) -> {true, H, inactive};
reply_to_ack_requests({false, H, A}, _, _) -> {false, H, A};
reply_to_ack_requests({true, H, inactive}, _, _) -> {true, H, inactive};
reply_to_ack_requests({true, H0, active}, Acks, State) ->
{true,
% TODO: Maybe compress here?
lists:foldl(fun({Ack,H}, _) -> raw_send(exml:to_iolist(Ack), State), H end,
lists:foldl(fun({Ack, H}, _) -> raw_send(exml:to_iolist(Ack), State), H end,
H0, Acks),
active}.

Expand Down
209 changes: 150 additions & 59 deletions src/escalus_ws.erl
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
send/2,
is_connected/1,
reset_parser/1,
get_sm_h/1,
set_sm_h/2,
use_zlib/1,
upgrade_to_tls/2,
set_filter_predicate/2,
Expand All @@ -40,16 +42,32 @@
-define(SERVER, ?MODULE).

-record(state, {owner, socket, parser, legacy_ws, compress = false,
event_client, filter_pred, stream_ref}).
event_client, sm_state, filter_pred, stream_ref, sent_stanzas = []}).
-type state() :: #state{}.

-type sm_state() :: {boolean(), non_neg_integer(), 'active'|'inactive'}.

-type opts() :: #{
host => string(),
port => pos_integer(),
wspath => string(),
wslegacy => boolean(),
event_client => undefined | escalus_event:event_client(),
ssl => boolean(),
ssl_opts => [ssl:ssl_option()],
ws_upgrade_timeout => pos_integer(),
stream_management => boolean(),
manual_ack => boolean()
}.

%%%===================================================================
%%% API
%%%===================================================================

-spec connect([proplists:property()]) -> pid().
connect(Args) ->
{ok, Pid} = gen_server:start_link(?MODULE, [Args, self()], []),
connect(Opts0) ->
Opts1 = opts_to_map(Opts0),
{ok, Pid} = gen_server:start_link(?MODULE, [Opts1, self()], []),
Pid.

-spec send(pid(), exml:element()) -> ok.
Expand All @@ -60,6 +78,15 @@ send(Pid, Elem) ->
is_connected(Pid) ->
erlang:is_process_alive(Pid).

-spec get_sm_h(pid()) -> non_neg_integer().
get_sm_h(Pid) ->
gen_server:call(Pid, get_sm_h).

-spec set_sm_h(pid(), non_neg_integer()) -> {ok, non_neg_integer()}.
set_sm_h(Pid, H) ->
gen_server:call(Pid, {set_sm_h, H}).


-spec reset_parser(pid()) -> ok.
reset_parser(Pid) ->
gen_server:cast(Pid, reset_parser).
Expand Down Expand Up @@ -151,35 +178,42 @@ assert_stream_end(StreamEndRep, Props) ->
error("Not a valid stream end", [StreamEndRep])
end.

%%%===================================================================
%%% Default options
%%%===================================================================

default_options() ->
#{host => "localhost",
port => 5280,
wspath => "/ws-xmpp",
wslegacy => false,
event_client => undefined,
ssl => false,
ssl_opts => [],
ws_upgrade_timeout => 5000,
stream_management => false,
manual_ack => false}.

%%%===================================================================
%%% gen_server callbacks
%%%===================================================================

%% TODO: refactor all opt defaults taken from Args into a default_opts function,
%% so that we know what options the module actually expects
-spec init(list()) -> {ok, state()}.
init([Args, Owner]) ->
Host = get_host(Args, "localhost"),
Port = get_port(Args, 5280),
Resource = get_resource(Args, "/ws-xmpp"),
LegacyWS = get_legacy_ws(Args, false),
EventClient = proplists:get_value(event_client, Args),
SSL = proplists:get_value(ssl, Args, false),
SSLOpts = proplists:get_value(ssl_opts, Args, []),
init([Opts, Owner]) ->
Opts1 = overwrite_default_opts(Opts, default_options()),
#{wspath := Resource,
wslegacy := LegacyWS,
event_client := EventClient,
ws_upgrade_timeout := Timeout} = Opts1,
SM = get_stream_management_opt(Opts1),
Resource1 = maybe_binary_to_list(Resource),

ConnPid = do_connect(Opts1),

%% Disable http2 in protocols
TransportOpts = case SSL of
true ->
#{transport => tls, protocols => [http],
tls_opts => SSLOpts};
_ ->
#{transport => tcp, protocols => [http]}
end,
{ok, ConnPid} = gun:open(Host, Port, TransportOpts),
{ok, http} = gun:await_up(ConnPid),
WSUpgradeHeaders = [{<<"sec-websocket-protocol">>, <<"xmpp">>}],
StreamRef = gun:ws_upgrade(ConnPid, Resource, WSUpgradeHeaders,
StreamRef = gun:ws_upgrade(ConnPid, Resource1, WSUpgradeHeaders,
#{protocols => [{<<"xmpp">>, gun_ws_h}]}),
Timeout = get_option(ws_upgrade_timeout, Args, 5000),
wait_for_ws_upgrade(ConnPid, StreamRef, Timeout),
ParserOpts = case LegacyWS of
true -> [];
Expand All @@ -190,6 +224,7 @@ init([Args, Owner]) ->
socket = ConnPid,
parser = Parser,
legacy_ws = LegacyWS,
sm_state = SM,
event_client = EventClient,
stream_ref = StreamRef}}.

Expand All @@ -208,6 +243,11 @@ wait_for_ws_upgrade(ConnPid, StreamRef, Timeout) ->

-spec handle_call(term(), {pid(), term()}, state()) ->
{reply, term(), state()} | {stop, normal, ok, state()}.
handle_call(get_sm_h, _From, #state{sm_state = {_, H, _}} = State) ->
{reply, H, State};
handle_call({set_sm_h, H}, _From, #state{sm_state = {A, _OldH, S}} = State) ->
NewState = State#state{sm_state={A, H, S}},
{reply, {ok, H}, NewState};
handle_call(use_zlib, _, #state{parser = Parser} = State) ->
Zin = zlib:open(),
Zout = zlib:open(),
Expand Down Expand Up @@ -259,6 +299,30 @@ code_change(_OldVsn, State, _Extra) ->
%%% Helpers
%%%===================================================================

-spec get_stream_management_opt(opts()) -> sm_state().
get_stream_management_opt(#{stream_management := false}) ->
{false, 0, inactive};
get_stream_management_opt(#{manual_ack := true}) ->
{false, 0, inactive};
get_stream_management_opt(#{stream_management := true, manual_ack := false}) ->
{true, 0, inactive}.

overwrite_default_opts(GivenOpts, DefaultOpts) ->
maps:merge(DefaultOpts, GivenOpts).

do_connect(#{ssl := true, ssl_opts := SSLOpts} = Opts) ->
TransportOpts = #{transport => tls, protocols => [http],
tls_opts => SSLOpts},
chrzaszcz marked this conversation as resolved.
Show resolved Hide resolved
do_connect(Opts, TransportOpts);
do_connect(Opts) ->
do_connect(Opts, #{transport => tcp, protocols => [http]}).

do_connect(#{host := Host, port := Port}, TransportOpts) ->
Host1 = maybe_binary_to_list(Host),
{ok, ConnPid} = gun:open(Host1, Port, TransportOpts),
{ok, http} = gun:await_up(ConnPid),
ConnPid.

handle_data(Data, State = #state{parser = Parser,
compress = Compress}) ->
Timestamp = os:system_time(micro_seconds),
Expand All @@ -270,11 +334,11 @@ handle_data(Data, State = #state{parser = Parser,
Decompressed = iolist_to_binary(zlib:inflate(Zin, Data)),
exml_stream:parse(Parser, Decompressed)
end,
NewState = State#state{parser = NewParser},
escalus_connection:maybe_forward_to_owner(NewState#state.filter_pred,
NewState,
Stanzas,
fun forward_to_owner/3, Timestamp),
FwdState = State#state{parser = NewParser, sent_stanzas = []},
NewState = escalus_connection:maybe_forward_to_owner(FwdState#state.filter_pred,
FwdState,
Stanzas,
fun forward_to_owner/3, Timestamp),
case lists:filter(fun(Stanza) -> is_stream_end(Stanza, State) end, Stanzas) of
[] -> {noreply, NewState};
_ -> {stop, normal, NewState}
Expand All @@ -285,44 +349,69 @@ is_stream_end(#xmlstreamend{}, #state{legacy_ws = true}) -> true;
is_stream_end(#xmlel{name = <<"close">>}, #state{legacy_ws = false}) -> true;
is_stream_end(_, _) -> false.

forward_to_owner(Stanzas, #state{owner = Owner,
event_client = EventClient}, Timestamp) ->
forward_to_owner(Stanzas0, #state{owner = Owner,
sm_state = SM0,
event_client = EventClient} = State, Timestamp) ->
chrzaszcz marked this conversation as resolved.
Show resolved Hide resolved
{SM1, AckRequests, StanzasNoRs} = separate_ack_requests(SM0, Stanzas0),
reply_to_ack_requests(SM1, AckRequests, State),

lists:foreach(fun(Stanza) ->
escalus_event:incoming_stanza(EventClient, Stanza),
Owner ! escalus_connection:stanza_msg(Stanza,
#{recv_timestamp => Timestamp})
end, Stanzas).
escalus_event:incoming_stanza(EventClient, Stanza),
Owner ! escalus_connection:stanza_msg(Stanza, #{recv_timestamp => Timestamp})
end, StanzasNoRs),

State#state{sm_state = SM1, sent_stanzas = StanzasNoRs}.

separate_ack_requests({false, H0, A}, Stanzas) ->
chrzaszcz marked this conversation as resolved.
Show resolved Hide resolved
%% Don't keep track of H
{{false, H0, A}, [], Stanzas};
separate_ack_requests({true, H0, inactive}, Stanzas) ->
Enabled = [ S || S <- Stanzas, escalus_pred:is_sm_enabled(S)],
Resumed = [ S || S <- Stanzas, escalus_pred:is_sm_resumed(S)],

case {length(Enabled), length(Resumed)} of
%% Enabled SM: set the H param to 0 and activate counter.
{1,0} -> {{true, 0, active}, [], Stanzas};

%% Resumed SM: keep the H param and activate counter.
{0,1} -> {{true, H0, active}, [], Stanzas};

%% No new SM state: continue as usual
{0,0} -> {{true, H0, inactive}, [], Stanzas}
end;
separate_ack_requests({true, H0, active}, Stanzas) ->
%% Count H and construct appropriate acks
F = fun(Stanza, {H, Acks, NonAckRequests}) ->
case escalus_pred:is_sm_ack_request(Stanza) of
true -> {H, [make_ack(H)|Acks], NonAckRequests};
false -> {H+1, Acks, [Stanza|NonAckRequests]}
end
end,
{H, Acks, Others} = lists:foldl(F, {H0, [], []}, Stanzas),
{{true, H, active}, lists:reverse(Acks), lists:reverse(Others)}.

make_ack(H) -> {escalus_stanza:sm_ack(H), H}.

reply_to_ack_requests({false, H, A}, _, _) ->
{false, H, A};
reply_to_ack_requests({true, H, inactive}, _, _) ->
{true, H, inactive};
reply_to_ack_requests({true, H0, active}, Acks, State) ->
{true,
lists:foldl(fun({Ack, H}, _) ->
Ack1 = exml:to_iolist(Ack),
gun:ws_send(State#state.socket, State#state.stream_ref, {text, Ack1}),
H
end, H0, Acks),
active}.

common_terminate(_Reason, #state{parser = Parser}) ->
exml_stream:free_parser(Parser).

-spec get_port(list(), inet:port_number()) -> inet:port_number().
get_port(Args, Default) ->
get_option(port, Args, Default).

-spec get_host(list(), string()) -> string().
get_host(Args, Default) ->
maybe_binary_to_list(get_option(host, Args, Default)).

-spec get_resource(list(), string()) -> string().
get_resource(Args, Default) ->
maybe_binary_to_list(get_option(wspath, Args, Default)).

-spec get_legacy_ws(list(), boolean()) -> boolean().
get_legacy_ws(Args, Default) ->
get_option(wslegacy, Args, Default).

-spec maybe_binary_to_list(binary() | string()) -> string().
maybe_binary_to_list(B) when is_binary(B) -> binary_to_list(B);
maybe_binary_to_list(S) when is_list(S) -> S.

-spec get_option(any(), list(), any()) -> any().
get_option(Key, Opts, Default) ->
case lists:keyfind(Key, 1, Opts) of
false -> Default;
{Key, Value} -> Value
end.

close_compression_streams(false) ->
ok;
close_compression_streams({zlib, {Zin, Zout}}) ->
Expand All @@ -337,4 +426,6 @@ close_compression_streams({zlib, {Zin, Zout}}) ->
ok = zlib:close(Zout)
end.


-spec opts_to_map([proplists:property()] | opts()) -> opts().
chrzaszcz marked this conversation as resolved.
Show resolved Hide resolved
opts_to_map(Opts) when is_map(Opts) -> Opts;
opts_to_map(Opts) when is_list(Opts) -> maps:from_list(Opts).
Loading