Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes required/useful by/for AMOC #62

Merged
merged 13 commits into from
Apr 28, 2015
23 changes: 18 additions & 5 deletions src/escalus_bosh.erl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

-module(escalus_bosh).
-behaviour(gen_server).
-behaviour(escalus_connection).

-include_lib("exml/include/exml_stream.hrl").
-include("include/escalus.hrl").
-include("include/escalus_xmlns.hrl").
-include("escalus.hrl").
-include("escalus_xmlns.hrl").
-include("no_binary_to_integer.hrl").

%% Escalus transport callbacks
Expand All @@ -21,7 +22,8 @@
get_transport/1,
reset_parser/1,
stop/1,
kill/1]).
kill/1,
set_filter_predicate/2]).

%% gen_server callbacks
-export([init/1,
Expand Down Expand Up @@ -67,7 +69,8 @@
terminated = false,
event_client,
client,
on_reply}).
on_reply,
filter_pred}).

%%%===================================================================
%%% API
Expand Down Expand Up @@ -109,6 +112,11 @@ use_zlib(#client{} = _Conn, _Props) ->
get_transport(#client{rcv_pid = Pid}) ->
gen_server:call(Pid, get_transport).

-spec set_filter_predicate(escalus_connection:client(),
escalus_connection:filter_pred()) -> ok.
set_filter_predicate(#client{rcv_pid = Pid}, Pred) ->
gen_server:call(Pid, {set_filter_pred, Pred}).

%%%===================================================================
%%% BOSH XML elements
%%%===================================================================
Expand Down Expand Up @@ -295,6 +303,9 @@ handle_call(recv, _From, State) ->
handle_call(get_requests, _From, State) ->
{reply, length(State#state.requests), State};

handle_call({set_filter_pred, Pred}, _From, State) ->
{reply, ok, State#state{filter_pred = Pred}};

handle_call(stop, _From, #state{} = State) ->
StreamEnd = escalus_stanza:stream_end(),
{ok, _Reply, NewState} =
Expand Down Expand Up @@ -407,7 +418,9 @@ handle_data(#xmlel{} = Body, #state{} = State) ->
Stanzas = unwrap_elem(Body),
case State#state.active of
true ->
forward_to_owner(Stanzas, NewState),
escalus_connection:maybe_forward_to_owner(NewState#state.filter_pred,
NewState, Stanzas,
fun forward_to_owner/2),
NewState;
false ->
store_reply(Body, NewState)
Expand Down
52 changes: 50 additions & 2 deletions src/escalus_connection.erl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
-module(escalus_connection).

-include_lib("exml/include/exml_stream.hrl").
-include("include/escalus.hrl").
-include("escalus.hrl").

%% High-level API
-export([start/1, start/2,
Expand All @@ -16,24 +16,45 @@
-export([connect/1,
send/2,
get_stanza/2,
get_stanza/3,
get_sm_h/1,
set_sm_h/2,
set_filter_predicate/2,
reset_parser/1,
is_connected/1,
kill/1]).

%% Behaviour helpers
-export([maybe_forward_to_owner/4]).

%% Public Types
-type client() :: #client{}.
-export_type([client/0]).

-type step_spec() :: atom() | {module(), atom()} | escalus_session:step().
-export_type([step_spec/0]).

-type filter_pred() :: fun((#xmlel{}) -> boolean()) | none.
-export_type([filter_pred/0]).

%% Private
-export([connection_step/2]).

-define(TIMEOUT, 1000).

%%%===================================================================
%%% Behaviour callback
%%%===================================================================
-callback connect([proplists:property()]) -> {ok, client()}.
-callback send(client(), #xmlel{}) -> no_return().
-callback stop(client()) -> ok | already_stopped.

-callback is_connected(client()) -> boolean().
-callback reset_parser(client()) -> no_return().
-callback kill(client()) -> no_return().
-callback set_filter_predicate(client(), filter_pred()) -> ok.


%%%===================================================================
%%% Public API
%%%===================================================================
Expand Down Expand Up @@ -130,10 +151,14 @@ send(#client{module = Mod, event_client = EventClient} = Client, Elem) ->

-spec get_stanza(client(), any()) -> #xmlel{}.
get_stanza(Conn, Name) ->
get_stanza(Conn, Name, ?TIMEOUT).

-spec get_stanza(client(), any(), timeout()) -> #xmlel{}.
get_stanza(Conn, Name, Timeout) ->
receive
{stanza, Conn, Stanza} ->
Stanza
after ?TIMEOUT ->
after Timeout ->
throw({timeout, Name})
end.

Expand All @@ -149,6 +174,10 @@ set_sm_h(#client{module = escalus_tcp} = Conn, H) ->
set_sm_h(#client{module = Mod}, _) ->
error({set_sm_h, {undefined_for_escalus_module, Mod}}).

-spec set_filter_predicate(client(), filter_pred()) -> ok.
set_filter_predicate(#client{module = Module} = Conn, Pred) ->
Module:set_filter_predicate(Conn, Pred).

reset_parser(#client{module = Mod} = Client) ->
Mod:reset_parser(Client).

Expand All @@ -162,6 +191,25 @@ stop(#client{module = Mod} = Client) ->
kill(#client{module = Mod} = Client) ->
Mod:kill(Client).


-spec maybe_forward_to_owner(filter_pred(), term(), [#xmlel{}],
fun(([#xmlel{}], term()) -> term()))
-> term().
maybe_forward_to_owner(none, State, _Stanzas, _Fun) ->
State;
maybe_forward_to_owner(FilterPred, State, Stanzas, Fun)
when is_function(FilterPred) ->
AllowedStanzas = lists:filter(FilterPred, Stanzas),
case AllowedStanzas of
[] ->
State;
_ ->
Fun(AllowedStanzas, State)
end;
maybe_forward_to_owner(_, State, Stanzas, Fun) ->
Fun(Stanzas, State).


%%%===================================================================
%%% Helpers
%%%===================================================================
Expand Down
24 changes: 16 additions & 8 deletions src/escalus_stanza.erl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@

-export([disco_info/1,
disco_info/2,
disco_items/1
disco_items/1,
disco_items/2
]).

-export([vcard_update/1,
Expand Down Expand Up @@ -505,6 +506,9 @@ disco_info(JID, Node) ->
disco_items(JID) ->
ItemsQuery = query_el(?NS_DISCO_ITEMS, []),
iq(JID, <<"get">>, [ItemsQuery]).
disco_items(JID, Node) ->
ItemsQuery = query_el(?NS_DISCO_ITEMS, [{<<"node">>, Node}], []),
iq(JID, <<"get">>, [ItemsQuery]).

search_fields([]) ->
[];
Expand Down Expand Up @@ -647,9 +651,9 @@ mam_lookup_messages_iq(QueryId, Start, End, WithJID) ->
%% Include an rsm id for a particular message.
mam_lookup_messages_iq(QueryId, Start, End, WithJID, DirectionWMessageId) ->
IQ = #xmlel{children=[Q]} = mam_lookup_messages_iq(QueryId, Start, End, WithJID),
Q2 = Q#xmlel{children = defined([
fmapM(fun rsm_after_or_before/1, DirectionWMessageId)
])},
RSM = defined([fmapM(fun rsm_after_or_before/1, DirectionWMessageId)]),
Other = Q#xmlel.children,
Q2 = Q#xmlel{children = Other ++ RSM},
IQ#xmlel{children=[Q2]}.

fmapM(_F, undefined) -> undefined;
Expand All @@ -666,18 +670,22 @@ end_elem(EndTime) ->
with_elem(BWithJID) ->
#xmlel{name = <<"with">>, children = #xmlcdata{content = BWithJID}}.

rsm_after_or_before({Direction, AbstractID}) when is_binary(AbstractID) ->
rsm_after_or_before({Direction, AbstractID, MaxCount}) ->
#xmlel{name = <<"set">>,
attrs = [{<<"xmlns">>, ?NS_RSM}],
children = [ direction_el(Direction, AbstractID) ]}.
children = defined([max(MaxCount), direction_el(Direction, AbstractID) ])}.

direction_el('after', AbstractID) when is_binary(AbstractID) ->
#xmlel{name = <<"after">>, children = #xmlcdata{content = AbstractID}};
direction_el('before', AbstractID) when is_binary(AbstractID) ->
#xmlel{name = <<"before">>, children = #xmlcdata{content = AbstractID}}.
#xmlel{name = <<"before">>, children = #xmlcdata{content = AbstractID}};
direction_el(_, undefined) ->
undefined.

max(N) when is_integer(N) ->
#xmlel{name = <<"max">>, children = #xmlcdata{content = integer_to_binary(N)}}.
#xmlel{name = <<"max">>, children = #xmlcdata{content = integer_to_binary(N)}};
max(_) ->
undefined.

mam_ns_attr() -> {<<"xmlns">>,?NS_MAM}.

Expand Down
41 changes: 32 additions & 9 deletions src/escalus_tcp.erl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

-module(escalus_tcp).
-behaviour(gen_server).
-behaviour(escalus_connection).

-include_lib("exml/include/exml_stream.hrl").
-include("escalus.hrl").
Expand All @@ -20,6 +21,7 @@
reset_parser/1,
get_sm_h/1,
set_sm_h/2,
set_filter_predicate/2,
stop/1,
kill/1]).

Expand All @@ -46,6 +48,7 @@
-record(state, {owner,
socket,
parser,
filter_pred,
ssl = false,
compress = false,
event_client,
Expand All @@ -55,7 +58,6 @@
sm_state = {true, 0, inactive} :: sm_state(),
replies = []}).


%%%===================================================================
%%% API
%%%===================================================================
Expand All @@ -81,6 +83,11 @@ get_sm_h(#client{rcv_pid = Pid}) ->
set_sm_h(#client{rcv_pid = Pid}, H) ->
gen_server:call(Pid, {set_sm_h, H}).

-spec set_filter_predicate(escalus_connection:client(),
escalus_connection:filter_pred()) -> ok.
set_filter_predicate(#client{rcv_pid = Pid}, Pred) ->
gen_server:call(Pid, {set_filter_pred, Pred}).

stop(#client{rcv_pid = Pid}) ->
try
gen_server:call(Pid, stop)
Expand Down Expand Up @@ -140,7 +147,11 @@ recv(#client{rcv_pid = Pid}) ->
init([Args, Owner]) ->
Host = proplists:get_value(host, Args, <<"localhost">>),
Port = proplists:get_value(port, Args, 5222),

Address = host_to_inet(Host),
EventClient = proplists:get_value(event_client, Args),
Interface = proplists:get_value(iface, Args),
IsSSLConnection = proplists:get_value(ssl, Args, false),

OnReplyFun = proplists:get_value(on_reply, Args, fun(_) -> ok end),
OnRequestFun = proplists:get_value(on_request, Args, fun(_) -> ok end),
Expand All @@ -154,9 +165,15 @@ init([Args, Owner]) ->
{true,false} -> {true, 0, inactive}
end,

Address = host_to_inet(Host),
IsSSLConnection = proplists:get_value(ssl, Args, false),
{ok, Socket} = do_connect(IsSSLConnection, Address, Port, Args, OnConnectFun),

BasicOpts = [binary, {active, once}],
SocketOpts = case Interface of
undefined -> BasicOpts;
_ -> [{ip, iface_to_ip_address(Interface)}] ++ BasicOpts
end,

{ok, Socket} = do_connect(IsSSLConnection, Address, Port, Args,
SocketOpts, OnConnectFun),
{ok, Parser} = exml_stream:new_parser(),
{ok, #state{owner = Owner,
socket = Socket,
Expand Down Expand Up @@ -198,6 +215,8 @@ handle_call(get_active, _From, #state{active = Active} = State) ->
{reply, Active, State};
handle_call({set_active, Active}, _From, State) ->
{reply, ok, State#state{active = Active}};
handle_call({set_filter_pred, Pred}, _From, State) ->
{reply, ok, State#state{filter_pred = Pred}};
handle_call(recv, _From, State) ->
{Reply, NS} = handle_recv(State),
{reply, Reply, NS};
Expand Down Expand Up @@ -272,11 +291,14 @@ handle_data(Socket, Data, #state{parser = Parser,
NewState = State#state{parser = NewParser},
case State#state.active of
true ->
forward_to_owner(Stanzas, NewState);
escalus_connection:maybe_forward_to_owner(NewState#state.filter_pred,
NewState, Stanzas,
fun forward_to_owner/2);
false ->
store_reply(Stanzas, NewState)
end.


forward_to_owner(Stanzas0, #state{owner = Owner,
sm_state = SM0,
event_client = EventClient} = State) ->
Expand Down Expand Up @@ -310,7 +332,6 @@ handle_recv(#state{replies = [Reply | Replies]} = S) ->
end,
{Reply, S#state{replies = Replies}}.


separate_ack_requests({false, H0, A}, Stanzas) ->
%% Don't keep track of H
{{false, H0, A}, [], Stanzas};
Expand Down Expand Up @@ -385,6 +406,9 @@ host_to_inet({_,_,_,_,_,_,_,_} = IP6) -> IP6;
host_to_inet(Address) when is_list(Address) orelse is_atom(Address) -> Address;
host_to_inet(BAddress) when is_binary(BAddress) -> binary_to_list(BAddress).

iface_to_ip_address({_,_,_,_} = IP4) -> IP4;
iface_to_ip_address({_,_,_,_,_,_,_,_} = IP6) -> IP6.

close_compression_streams(false) ->
ok;
close_compression_streams({zlib, {Zin, Zout}}) ->
Expand All @@ -411,10 +435,9 @@ send_stream_end(#state{socket = Socket, ssl = Ssl, compress = Compress}) ->
gen_tcp:send(Socket, exml:to_iolist(StreamEnd))
end.

do_connect(IsSSLConnection, Address, Port, Args, OnConnectFun) ->
Opts = [binary, {active, once}],
do_connect(IsSSLConnection, Address, Port, Args, SocketOpts, OnConnectFun) ->
TimeB = os:timestamp(),
Reply = maybe_ssl_connection(IsSSLConnection, Address, Port, Opts, Args),
Reply = maybe_ssl_connection(IsSSLConnection, Address, Port, SocketOpts, Args),
TimeA = os:timestamp(),
ConnectionTime = timer:now_diff(TimeA, TimeB),
case Reply of
Expand Down
Loading