From 86f779520574d03d5be3f22d811d9348aa05a6b9 Mon Sep 17 00:00:00 2001 From: Kamil Waz Date: Tue, 25 Oct 2022 09:23:43 +0200 Subject: [PATCH] undo 2 --- src/c2s/mongoose_c2s.erl | 6 +- src/ejabberd_socket.erl | 16 - src/mod_bosh_socket.erl | 2 +- .../mongoose_system_metrics_collector.erl | 4 +- test/acc_SUITE.erl | 33 + test/ejabberd_sm_SUITE.erl | 615 ++++++++++++++++++ 6 files changed, 654 insertions(+), 22 deletions(-) create mode 100644 test/ejabberd_sm_SUITE.erl diff --git a/src/c2s/mongoose_c2s.erl b/src/c2s/mongoose_c2s.erl index 6223fe5946..f8a942adeb 100644 --- a/src/c2s/mongoose_c2s.erl +++ b/src/c2s/mongoose_c2s.erl @@ -371,7 +371,7 @@ handle_starttls(StateData = #c2s_data{socket = TcpSocket, handle_auth_start(StateData, El, SaslState, Retries) -> case {mongoose_c2s_socket:is_ssl(StateData#c2s_data.socket), StateData#c2s_data.listener_opts} of {false, #{tls := #{mode := starttls_required}}} -> - Error = mongoose_xmpp_errors:policy_violation( + Error = mongoose_xmpp_errors:policy_violation( StateData#c2s_data.lang, <<"Use of STARTTLS required">>), c2s_stream_error(StateData, Error); _ -> @@ -547,8 +547,8 @@ maybe_retry_state(StateData = #c2s_data{listener_opts = LOpts}, C2SState) -> -spec handle_cast(c2s_data(), c2s_state(), term()) -> fsm_res(). handle_cast(StateData, _C2SState, {exit, Reason}) when is_binary(Reason) -> - Error = mongoose_xmpp_errors:stream_conflict(StateData#c2s_data.lang, Reason), - send_element_from_server_jid(StateData, Error), + StreamConflict = mongoose_xmpp_errors:stream_conflict(StateData#c2s_data.lang, Reason), + send_element_from_server_jid(StateData, StreamConflict), send_trailer(StateData), {stop, {shutdown, Reason}}; handle_cast(StateData, _C2SState, {exit, system_shutdown}) -> diff --git a/src/ejabberd_socket.erl b/src/ejabberd_socket.erl index 7320191c04..9a85923ada 100644 --- a/src/ejabberd_socket.erl +++ b/src/ejabberd_socket.erl @@ -303,19 +303,3 @@ peername(#socket_state{sockmod = SockMod, socket = Socket}) -> -spec get_socket(socket_state()) -> term(). get_socket(#socket_state{socket = Socket}) -> Socket. - - -% format_socket(#socket_state{sockmod = Mod, socket = Socket, -% receiver = Receiver, connection_details = Info}) -> -% Info2 = format_details(Info), -% Info2#{socket_module => Mod, -% socket => format_term(Socket), -% receiver => format_term(Receiver)}; -% format_socket(_) -> -% #{}. - -% format_term(X) -> iolist_to_binary(io_lib:format("~0p", [X])). - -% format_details(Info = #{dest_address := DestAddr, src_address := SrcAddr}) -> -% Info#{dest_address => inet:ntoa(DestAddr), -% src_address => inet:ntoa(SrcAddr)}. diff --git a/src/mod_bosh_socket.erl b/src/mod_bosh_socket.erl index 40db379a25..eb856157a9 100644 --- a/src/mod_bosh_socket.erl +++ b/src/mod_bosh_socket.erl @@ -435,7 +435,7 @@ handle_info(Info, SName, State) -> terminate(Reason, StateName, #state{sid = Sid, handlers = Handlers} = S) -> [Pid ! {close, Sid} || {_, _, Pid} <- lists:sort(Handlers)], mod_bosh_backend:delete_session(Sid), - catch ejabberd_c2s:stop(S#state.c2s_pid), + catch mongoose_c2s:stop(S#state.c2s_pid, normal), ?LOG_DEBUG(ls(#{what => bosh_socket_closing_session, reason => Reason, state_name => StateName, handlers => Handlers, pending => S#state.pending}, S)). diff --git a/src/system_metrics/mongoose_system_metrics_collector.erl b/src/system_metrics/mongoose_system_metrics_collector.erl index 4cca5f7e82..f643d42139 100644 --- a/src/system_metrics/mongoose_system_metrics_collector.erl +++ b/src/system_metrics/mongoose_system_metrics_collector.erl @@ -132,7 +132,7 @@ filter_unknown_api(ApiList) -> get_transport_mechanisms() -> HTTP = [Mod || Mod <- get_http_handler_modules(), Mod =:= mod_bosh orelse Mod =:= mod_websockets], - TCP = lists:usort([tcp || #{proto := tcp} <- get_listeners(ejabberd_c2s)]), + TCP = lists:usort([tcp || #{proto := tcp} <- get_listeners(mongoose_c2s_listener)]), [#{report_name => transport_mechanism, key => Transport, value => enabled} || Transport <- HTTP ++ TCP]. @@ -149,7 +149,7 @@ get_http_handler_modules(#{handlers := Handlers}) -> [Module || #{module := Module} <- Handlers]. get_tls_options() -> - TLSOptions = lists:flatmap(fun extract_tls_options/1, get_listeners(ejabberd_c2s)), + TLSOptions = lists:flatmap(fun extract_tls_options/1, get_listeners(mongoose_c2s_listener)), [#{report_name => tls_option, key => TLSMode, value => TLSModule} || {TLSMode, TLSModule} <- lists:usort(TLSOptions)]. diff --git a/test/acc_SUITE.erl b/test/acc_SUITE.erl index 3d83b9347f..7449cdc933 100644 --- a/test/acc_SUITE.erl +++ b/test/acc_SUITE.erl @@ -30,6 +30,7 @@ groups() -> store_retrieve_and_delete_many, init_from_element, produce_iq_meta_automatically, + strip, strip_with_params, parse_with_cdata ] @@ -129,6 +130,38 @@ parse_with_cdata(_C) -> {XMLNS, _} = mongoose_iq:xmlns(Acc), ?assertEqual(<<"jabber:iq:roster">>, XMLNS). +strip(_C) -> + El = iq_stanza(), + FromJID = jid:make(<<"jajid">>, <<"localhost">>, <<>>), + ToJID = jid:make(<<"tyjid">>, <<"localhost">>, <<>>), + Server = maps:get(lserver, ?ACC_PARAMS), + HostType = maps:get(host_type, ?ACC_PARAMS), + Acc = mongoose_acc:new(?ACC_PARAMS#{element => El, + from_jid => FromJID, + to_jid => ToJID}), + {XMLNS1, Acc1} = mongoose_iq:xmlns(Acc), + ?assertEqual(<<"urn:ietf:params:xml:ns:xmpp-session">>, XMLNS1), + ?assertEqual(<<"set">>, mongoose_acc:stanza_type(Acc1)), + ?assertEqual(undefined, mongoose_acc:get(ns, ppp, undefined, Acc1)), + Acc2 = mongoose_acc:set_permanent(ns, ppp, 997, Acc1), + Acc3 = mongoose_acc:set(ns2, [{a, 1}, {b, 2}], Acc2), + ?assertMatch([_, _], mongoose_acc:get(ns2, Acc3)), + ?assertEqual(Server, mongoose_acc:lserver(Acc3)), + ?assertEqual(HostType, mongoose_acc:host_type(Acc3)), + ?assertEqual({FromJID, ToJID, El}, mongoose_acc:packet(Acc3)), + Ref = mongoose_acc:ref(Acc3), + %% strip stanza and check that only non-permanent fields are missing + NAcc1 = mongoose_acc:strip(Acc3), + {XMLNS3, NAcc2} = mongoose_iq:xmlns(NAcc1), + ?assertEqual(<<"urn:ietf:params:xml:ns:xmpp-session">>, XMLNS3), + ?assertEqual(<<"set">>, mongoose_acc:stanza_type(NAcc2)), + ?assertEqual(Server, mongoose_acc:lserver(NAcc2)), + ?assertEqual(HostType, mongoose_acc:host_type(NAcc2)), + ?assertEqual({FromJID, ToJID, El}, mongoose_acc:packet(NAcc2)), + ?assertEqual(Ref, mongoose_acc:ref(NAcc2)), + ?assertEqual(997, mongoose_acc:get(ns, ppp, NAcc2)), + ?assertEqual([], mongoose_acc:get(ns2, NAcc2)). + strip_with_params(_Config) -> FromJID = jid:make(<<"jajid">>, <<"localhost">>, <<>>), ToJID = jid:make(<<"tyjid">>, <<"localhost">>, <<>>), diff --git a/test/ejabberd_sm_SUITE.erl b/test/ejabberd_sm_SUITE.erl new file mode 100644 index 0000000000..aa85adcfa8 --- /dev/null +++ b/test/ejabberd_sm_SUITE.erl @@ -0,0 +1,615 @@ +-module(ejabberd_sm_SUITE). +-include_lib("eunit/include/eunit.hrl"). +-include_lib("common_test/include/ct.hrl"). + +-include_lib("jid/include/jid.hrl"). +-include_lib("session.hrl"). +-compile([export_all, nowarn_export_all]). + +-define(eq(E, I), ?assertEqual(E, I)). + +-define(B(C), (proplists:get_value(backend, C))). +-define(MAX_USER_SESSIONS, 2). + +-import(config_parser_helper, [default_config/1]). + +all() -> [{group, mnesia}, {group, redis}]. + +init_per_suite(C) -> + {ok, _} = application:ensure_all_started(jid), + application:ensure_all_started(exometer_core), + F = fun() -> + ejabberd_sm_backend_sup:start_link(), + receive stop -> ok end + end, + Pid = spawn(F), + [{pid, Pid} | C]. + +end_per_suite(C) -> + Pid = ?config(pid, C), + Pid ! stop, + application:stop(exometer), + application:stop(exometer_core). + +groups() -> + [{mnesia, [], tests()}, + {redis, [], tests()}]. + +tests() -> + [open_session, + get_full_session_list, + get_vh_session_list, + get_sessions_2, + get_sessions_3, + session_is_updated_when_created_twice, + delete_session, + clean_up, + too_much_sessions, + unique_count, + unique_count_while_removing_entries, + session_info_is_stored, + session_info_is_updated_if_keys_match, + session_info_is_extended_if_new_keys_present, + session_info_keys_not_truncated_if_session_opened_with_empty_infolist, + kv_can_be_stored_for_session, + kv_can_be_updated_for_session, + kv_can_be_removed_for_session, + store_info_sends_message_to_the_session_owner, + remove_info_sends_message_to_the_session_owner, + cannot_reproduce_race_condition_in_store_info + ]. + +init_per_group(mnesia, Config) -> + ok = mnesia:create_schema([node()]), + ok = mnesia:start(), + [{backend, ejabberd_sm_mnesia} | Config]; +init_per_group(redis, Config) -> + init_redis_group(is_redis_running(), Config). + +init_redis_group(true, Config) -> + Self = self(), + proc_lib:spawn(fun() -> + register(test_helper, self()), + mongoose_wpool:ensure_started(), + % This would be started via outgoing_pools in normal case + Pool = default_config([outgoing_pools, redis, default]), + mongoose_wpool:start_configured_pools([Pool], []), + Self ! ready, + receive stop -> ok end + end), + receive ready -> ok after timer:seconds(30) -> ct:fail(test_helper_not_ready) end, + [{backend, ejabberd_sm_redis} | Config]; +init_redis_group(_, _) -> + {skip, "redis not running"}. + +end_per_group(mnesia, Config) -> + mnesia:stop(), + mnesia:delete_schema([node()]), + Config; +end_per_group(_, Config) -> + whereis(test_helper) ! stop, + Config. + +init_per_testcase(too_much_sessions, Config) -> + set_test_case_meck(?MAX_USER_SESSIONS), + setup_sm(Config), + Config; +init_per_testcase(_, Config) -> + set_test_case_meck(infinity), + setup_sm(Config), + Config. + +end_per_testcase(_, Config) -> + clean_sessions(Config), + terminate_sm(), + unload_meck(), + unset_opts(Config). + +open_session(C) -> + {Sid, USR} = generate_random_user(<<"localhost">>), + given_session_opened(Sid, USR), + verify_session_opened(C, Sid, USR). + +get_full_session_list(C) -> + ManyUsers = generate_many_random_users(5, [<<"localhost">>, <<"otherhost">>]), + ManyUsersLen = length(ManyUsers), + [given_session_opened(Sid, USR) || {Sid, USR} <- ManyUsers], + AllSessions = ejabberd_sm:get_full_session_list(), + AllSessionsLen = length(AllSessions), + AllSessionsLen = ManyUsersLen, + [verify_session_opened(C, Sid, USR) || {Sid, USR} <- ManyUsers]. + +get_vh_session_list(C) -> + ManyUsersLocal = generate_many_random_users(5, [<<"localhost">>]), + ManyUsersOther = generate_many_random_users(5, [<<"otherhost">>]), + ManyUsersLocalLen = length(ManyUsersLocal), + [given_session_opened(Sid, USR) || {Sid, USR} <- ManyUsersLocal ++ ManyUsersOther], + LocalhostSessions = ejabberd_sm:get_vh_session_list(<<"localhost">>), + LocalhostSessionsLen = length(LocalhostSessions), + LocalhostSessionsLen = ManyUsersLocalLen, + ManyUsersLocalLen = ejabberd_sm:get_vh_session_number(<<"localhost">>), + [verify_session_opened(C, Sid, USR) || {Sid, USR} <- ManyUsersLocal]. + +get_sessions_2(C) -> + UsersWithManyResources = generate_many_random_res(5, 3, [<<"localhost">>, <<"otherhost">>]), + [given_session_opened(Sid, USR) || {Sid, USR} <- UsersWithManyResources], + USDict = get_unique_us_dict(UsersWithManyResources), + [verify_session_opened(C, U, S, dict:fetch({U, S}, USDict)) || {U, S} <- dict:fetch_keys(USDict)], + [verify_session_opened(C, Sid, USR) || {Sid, USR} <- UsersWithManyResources]. + + +get_sessions_3(C) -> + UserRes = generate_many_random_res(1, 3, [<<"localhost">>]), + AllSessions = length(UserRes), + {_, {User, Server, _}} = hd(UserRes), + [given_session_opened(Sid, USR) || {Sid, USR} <- UserRes], + Sessions_2 = ?B(C):get_sessions(User, Server), + AllSessions = length(Sessions_2), + F = fun({Sid, {U, S, R} = USR}) -> + [#session{sid = Sid} = Session] = ?B(C):get_sessions(U, S, R), + Session = lists:keyfind(Sid, #session.sid, Sessions_2), + Session = lists:keyfind(USR, #session.usr, Sessions_2), + true + end, + true = lists:all(F, UserRes). + +session_is_updated_when_created_twice(C) -> + {Sid, {U, S, _} = USR} = generate_random_user(<<"localhost">>), + given_session_opened(Sid, USR), + verify_session_opened(C, Sid, USR), + + given_session_opened(Sid, USR, 20), + verify_session_opened(C, Sid, USR), + + [#session{usr = USR, sid = Sid, priority = 20}] = ?B(C):get_sessions(), + [#session{usr = USR, sid = Sid, priority = 20}] = ?B(C):get_sessions(S), + [#session{priority = 20}] = ?B(C):get_sessions(U, S). + +session_info_is_stored(C) -> + {Sid, {U, S, _} = USR} = generate_random_user(<<"localhost">>), + given_session_opened(Sid, USR, 1, [{key1, val1}]), + + [#session{sid = Sid, info = #{key1 := val1}}] + = ?B(C):get_sessions(U,S). + +session_info_is_updated_if_keys_match(C) -> + {Sid, {U, S, _} = USR} = generate_random_user(<<"localhost">>), + given_session_opened(Sid, USR, 1, [{key1, val1}]), + + when_session_opened(Sid, USR, 1, [{key1, val2}]), + + [#session{sid = Sid, info = #{key1 := val2}}] + = ?B(C):get_sessions(U,S). + +session_info_is_extended_if_new_keys_present(C) -> + {Sid, {U, S, _} = USR} = generate_random_user(<<"localhost">>), + given_session_opened(Sid, USR, 1, [{key1, val1}]), + + when_session_opened(Sid, USR, 1, [{key1, val1}, {key2, val2}]), + + [#session{sid = Sid, info = #{key1 := val1, key2 := val2}}] + = ?B(C):get_sessions(U,S). + +session_info_keys_not_truncated_if_session_opened_with_empty_infolist(C) -> + {Sid, {U, S, _} = USR} = generate_random_user(<<"localhost">>), + given_session_opened(Sid, USR, 1, [{key1, val1}]), + + when_session_opened(Sid, USR, 1, []), + + [#session{sid = Sid, info = #{key1 := val1}}] + = ?B(C):get_sessions(U,S). + + +kv_can_be_stored_for_session(C) -> + {Sid, {U, S, R} = USR} = generate_random_user(<<"localhost">>), + given_session_opened(Sid, USR, 1, [{key1, val1}]), + + when_session_info_stored(U, S, R, {key2, newval}), + + ?assertMatch([#session{sid = Sid, info = #{key1 := val1, key2 := newval}}], + ?B(C):get_sessions(U,S)). + +kv_can_be_updated_for_session(C) -> + {Sid, {U, S, R} = USR} = generate_random_user(<<"localhost">>), + given_session_opened(Sid, USR, 1, [{key1, val1}]), + + when_session_info_stored(U, S, R, {key2, newval}), + when_session_info_stored(U, S, R, {key2, override}), + + ?assertMatch([#session{sid = Sid, info = #{key1 := val1, key2 := override}}], + ?B(C):get_sessions(U, S)). + +kv_can_be_removed_for_session(C) -> + {Sid, {U, S, R} = USR} = generate_random_user(<<"localhost">>), + given_session_opened(Sid, USR, 1, [{key1, val1}]), + + when_session_info_stored(U, S, R, {key2, newval}), + + [#session{sid = Sid, info = #{key1 := val1, key2 := newval}}] + = ?B(C):get_sessions(U, S), + + when_session_info_removed(U, S, R, key2), + + [#session{sid = Sid, info = #{key1 := val1}}] + = ?B(C):get_sessions(U, S), + + when_session_info_removed(U, S, R, key1), + + [#session{sid = Sid, info = #{}}] + = ?B(C):get_sessions(U, S). + +cannot_reproduce_race_condition_in_store_info(C) -> + ok = try_to_reproduce_race_condition(C). + +store_info_sends_message_to_the_session_owner(C) -> + SID = {erlang:system_time(microsecond), self()}, + U = <<"alice2">>, + S = <<"localhost">>, + R = <<"res1">>, + Session = #session{sid = SID, usr = {U, S, R}, us = {U, S}, priority = 1, info = #{}}, + %% Create session in one process + ?B(C):create_session(U, S, R, Session), + %% but call store_info from another process + JID = jid:make_noprep(U, S, R), + spawn_link(fun() -> ejabberd_sm:store_info(JID, cc, undefined) end), + %% The original process receives a message + receive {store_session_info, + #jid{luser = User, lserver = Server, lresource = Resource}, + K, V, _FromPid} -> + ?eq(U, User), + ?eq(S, Server), + ?eq(R, Resource), + ?eq({cc, undefined}, {K, V}), + ok + after 5000 -> + ct:fail("store_info_sends_message_to_the_session_owner=timeout") + end. + +remove_info_sends_message_to_the_session_owner(C) -> + SID = {erlang:timestamp(), self()}, + U = <<"alice2">>, + S = <<"localhost">>, + R = <<"res1">>, + Session = #session{sid = SID, usr = {U, S, R}, us = {U, S}, priority = 1, info = #{}}, + %% Create session in one process + ?B(C):create_session(U, S, R, Session), + %% but call remove_info from another process + JID = jid:make_noprep(U, S, R), + spawn_link(fun() -> ejabberd_sm:remove_info(JID, cc) end), + %% The original process receives a message + receive {remove_session_info, + #jid{luser = User, lserver = Server, lresource = Resource}, + Key, _FromPid} -> + ?eq(U, User), + ?eq(S, Server), + ?eq(R, Resource), + ?eq(cc, Key), + ok + after 5000 -> + ct:fail("remove_info_sends_message_to_the_session_owner=timeout") + end. + +delete_session(C) -> + {Sid, {U, S, R} = USR} = generate_random_user(<<"localhost">>), + given_session_opened(Sid, USR), + verify_session_opened(C, Sid, USR), + + ?B(C):delete_session(Sid, U, S, R), + + [] = ?B(C):get_sessions(), + [] = ?B(C):get_sessions(S), + [] = ?B(C):get_sessions(U, S), + [] = ?B(C):get_sessions(U, S, R). + + + +clean_up(C) -> + UsersWithManyResources = generate_many_random_res(5, 3, [<<"localhost">>, <<"otherhost">>]), + [given_session_opened(Sid, USR) || {Sid, USR} <- UsersWithManyResources], + ?B(C):cleanup(node()), + %% give sm backend some time to clean all sessions + ensure_empty(C, 10, ?B(C):get_sessions()). + +ensure_empty(_C, 0, Sessions) -> + [] = Sessions; +ensure_empty(C, N, Sessions) -> + case Sessions of + [] -> + ok; + _ -> + timer:sleep(50), + ensure_empty(C, N-1, ?B(C):get_sessions()) + end. + +too_much_sessions(_C) -> + %% Max sessions set to ?MAX_USER_SESSIONS in init_per_testcase + UserSessions = [generate_random_user(<<"a">>, <<"localhost">>) || _ <- lists:seq(1, ?MAX_USER_SESSIONS)], + {AddSid, AddUSR} = generate_random_user(<<"a">>, <<"localhost">>), + + [given_session_opened(Sid, USR) || {Sid, USR} <- UserSessions], + + given_session_opened(AddSid, AddUSR), + + receive + replaced -> + ok + after 10 -> + ct:fail("replaced message not sent") + end. + + + +unique_count(_C) -> + UsersWithManyResources = generate_many_random_res(5, 3, [<<"localhost">>, <<"otherhost">>]), + [given_session_opened(Sid, USR) || {Sid, USR} <- UsersWithManyResources], + USDict = get_unique_us_dict(UsersWithManyResources), + UniqueCount = ejabberd_sm:get_unique_sessions_number(), + UniqueCount = dict:size(USDict). + + +unique_count_while_removing_entries(C) -> + unique_count(C), + UniqueCount = ejabberd_sm:get_unique_sessions_number(), + %% Register more sessions and mock the crash + UsersWithManyResources = generate_many_random_res(10, 3, [<<"localhost">>, <<"otherhost">>]), + [given_session_opened(Sid, USR) || {Sid, USR} <- UsersWithManyResources], + set_test_case_meck_unique_count_crash(?B(C)), + USDict = get_unique_us_dict(UsersWithManyResources), + %% Check if unique count equals prev cached value + UniqueCount = ejabberd_sm:get_unique_sessions_number(), + meck:unload(?B(C)), + true = UniqueCount /= dict:size(USDict) + UniqueCount. + +unload_meck() -> + meck:unload(acl), + meck:unload(gen_hook), + meck:unload(ejabberd_commands), + meck:unload(mongoose_domain_api). + +set_test_case_meck(MaxUserSessions) -> + meck:new(acl, []), + meck:expect(acl, match_rule, fun(_, _, _, _) -> MaxUserSessions end), + meck:new(gen_hook, []), + meck:expect(gen_hook, run_fold, fun(_, _, Acc, _) -> {ok, Acc} end), + meck:new(mongoose_domain_api, []), + meck:expect(mongoose_domain_api, get_domain_host_type, fun(H) -> {ok, H} end). + +set_test_case_meck_unique_count_crash(Backend) -> + F = get_fun_for_unique_count(Backend), + meck:new(Backend, []), + meck:expect(Backend, unique_count, F). + +get_fun_for_unique_count(ejabberd_sm_mnesia) -> + fun() -> + mnesia:abort({badarg,[session,{{1442,941593,580189},list_to_pid("<0.23291.6>")}]}) + end; +get_fun_for_unique_count(ejabberd_sm_redis) -> + fun() -> + %% The code below is on purpose, it's to crash with badarg reason + length({error, timeout}) + end. + +make_sid() -> + {erlang:timestamp(), self()}. + +given_session_opened(Sid, USR) -> + given_session_opened(Sid, USR, 1). + +given_session_opened(Sid, {U, S, R}, Priority) -> + given_session_opened(Sid, {U, S, R}, Priority, []). + +given_session_opened(Sid, {U, S, R}, Priority, Info) -> + JID = jid:make_noprep(U, S, R), + ejabberd_sm:open_session(S, Sid, JID, Priority, maps:from_list(Info)). + +when_session_opened(Sid, {U, S, R}, Priority, Info) -> + given_session_opened(Sid, {U, S, R}, Priority, Info). + +when_session_info_stored(U, S, R, {K, V}) -> + JID = jid:make_noprep(U, S, R), + ejabberd_sm:store_info(JID, K, V). + +when_session_info_removed(U, S, R, Key) -> + JID = jid:make_noprep(U, S, R), + ejabberd_sm:remove_info(JID, Key). + +verify_session_opened(C, Sid, USR) -> + do_verify_session_opened(?B(C), Sid, USR). + +do_verify_session_opened(ejabberd_sm_mnesia, Sid, {U, S, R} = USR) -> + general_session_check(ejabberd_sm_mnesia, Sid, USR, U, S, R); +do_verify_session_opened(ejabberd_sm_redis, Sid, {U, S, R} = USR) -> + UHash = iolist_to_binary(hash(U, S, R, Sid)), + Hashes = mongoose_redis:cmd(["SMEMBERS", n(node())]), + true = lists:member(UHash, Hashes), + SessionsUSEncoded = mongoose_redis:cmd(["SMEMBERS", hash(U, S)]), + SessionsUS = [binary_to_term(Entry) || Entry <- SessionsUSEncoded], + true = lists:keymember(Sid, 2, SessionsUS), + [SessionUSREncoded] = mongoose_redis:cmd(["SMEMBERS", hash(U, S, R)]), + SessionUSR = binary_to_term(SessionUSREncoded), + #session{sid = Sid} = SessionUSR, + general_session_check(ejabberd_sm_redis, Sid, USR, U, S, R). + +verify_session_opened(C, U, S, Resources) -> + Sessions = ?B(C):get_sessions(U, S), + F = fun({Sid, USR}) -> + #session{} = Session = lists:keyfind(Sid, #session.sid, Sessions), + Session == lists:keyfind(USR, #session.usr, Sessions) + end, + true = lists:all(F, Resources). + +general_session_check(M, Sid, USR, U, S, R) -> + [#session{sid = Sid, usr = USR, us = {U, S}}] = M:get_sessions(U, S, R). + +clean_sessions(C) -> + case ?B(C) of + ejabberd_sm_mnesia -> + mnesia:clear_table(session); + ejabberd_sm_redis -> + mongoose_redis:cmd(["FLUSHALL"]) + end. + +generate_random_user(S) -> + U = base16:encode(crypto:strong_rand_bytes(5)), + generate_random_user(U, S). + +generate_random_user(U, S) -> + R = base16:encode(crypto:strong_rand_bytes(5)), + generate_user(U, S, R). + +generate_user(U, S, R) -> + Sid = make_sid(), + {Sid, {U, S, R}}. + +generate_many_random_users(PerServerCount, Servers) -> + Users = [generate_random_users(PerServerCount, Server) || Server <- Servers], + lists:flatten(Users). + +generate_random_users(Count, Server) -> + [generate_random_user(Server) || _ <- lists:seq(1, Count)]. + +generate_many_random_res(UsersPerServer, ResourcesPerUser, Servers) -> + Usernames = [base16:encode(crypto:strong_rand_bytes(5)) || _ <- lists:seq(1, UsersPerServer)], + [generate_random_user(U, S) || U <- Usernames, S <- Servers, _ <- lists:seq(1, ResourcesPerUser)]. + +get_unique_us_dict(USRs) -> + F = fun({_, {U, S, _}} = I, SetAcc) -> + dict:append({U, S}, I, SetAcc) + end, + lists:foldl(F, dict:new(), USRs). + +%% Taken from ejabberd_sm_redis + +-spec hash(binary()) -> iolist(). +hash(Val1) -> + ["s3:*:", Val1, ":*"]. + + +-spec hash(binary(), binary()) -> iolist(). +hash(Val1, Val2) -> + ["s2:", Val1, ":", Val2]. + + +-spec hash(binary(), binary(), binary()) -> iolist(). +hash(Val1, Val2, Val3) -> + ["s3:", Val1, ":", Val2, ":", Val3]. + + +-spec hash(binary(), binary(), binary(), binary()) -> iolist(). +hash(Val1, Val2, Val3, Val4) -> + ["s4:", Val1, ":", Val2, ":", Val3, ":", term_to_binary(Val4)]. + + +-spec n(atom()) -> iolist(). +n(Node) -> + ["n:", atom_to_list(Node)]. + + +is_redis_running() -> + case eredis:start_link() of + {ok, Client} -> + Result = eredis:q(Client, [<<"PING">>], 5000), + eredis:stop(Client), + case Result of + {ok,<<"PONG">>} -> + true; + _ -> + false + end; + _ -> + false + end. + +try_to_reproduce_race_condition(Config) -> + SID = {erlang:timestamp(), self()}, + U = <<"alice">>, + S = <<"localhost">>, + R = <<"res1">>, + Session = #session{sid = SID, usr = {U, S, R}, us = {U, S}, priority = 1, info = #{}}, + ?B(Config):create_session(U, S, R, Session), + Parent = self(), + %% Add some instrumentation to simulate race conditions + %% The goal is to delete the session after other process reads it + %% but before it updates it. In other words, delete a record + %% between get_sessions and create_session in ejabberd_sm:store_info + %% Step1 prepare concurrent processes + DeleterPid = spawn_link(fun() -> + receive start -> ok end, + ?B(Config):delete_session(SID, U, S, R), + Parent ! p1_done + end), + SetterPid = spawn_link(fun() -> + receive start -> ok end, + when_session_info_stored(U, S, R, {cc, undefined}), + Parent ! p2_done + end), + %% Step2 setup mocking for some ejabbers_sm_mnesia functions + meck:new(?B(Config), []), + %% When the first get_sessions (run from ejabberd_sm:store_info) + %% is executed, the start msg is sent to Deleter process + %% Thanks to that, setter will get not empty list of sessions + PassThrough3 = fun(A, B, C) -> + DeleterPid ! start, + meck:passthrough([A, B, C]) end, + meck:expect(?B(Config), get_sessions, PassThrough3), + %% Wait some time before setting the sessions + %% so we are sure delete operation finishes + meck:expect(?B(Config), create_session, + fun(U1, S1, R1, Session1) -> + timer:sleep(100), + meck:passthrough([U1, S1, R1, Session1]) + end), + PassThrough4 = fun(A, B, C, D) -> meck:passthrough([A, B, C, D]) end, + meck:expect(?B(Config), delete_session, PassThrough4), + %% Start the play from setter process + SetterPid ! start, + %% Wait for both process to finish + receive p1_done -> ok end, + receive p2_done -> ok end, + meck:unload(?B(Config)), + %% Session should not exist + case ?B(Config):get_sessions(U, S, R) of + [] -> + ok; + Other -> + error_logger:error_msg("issue=reproduced, sid=~p, other=~1000p", + [SID, Other]), + {error, reproduced} + end. + +setup_sm(Config) -> + set_opts(Config), + set_meck(), + ejabberd_sm:start_link(), + case ?config(backend, Config) of + ejabberd_sm_redis -> + mongoose_redis:cmd(["FLUSHALL"]); + ejabberd_sm_mnesia -> + ok + end. + +terminate_sm() -> + gen_server:stop(ejabberd_sm). + +set_opts(Config) -> + [mongoose_config:set_opt(Key, Value) || {Key, Value} <- opts(Config)]. + +unset_opts(Config) -> + [mongoose_config:unset_opt(Key) || {Key, _Value} <- opts(Config)]. + +opts(Config) -> + [{hosts, [<<"localhost">>]}, + {host_types, []}, + {all_metrics_are_global, false}, + {sm_backend, sm_backend(?config(backend, Config))}]. + +sm_backend(ejabberd_sm_redis) -> redis; +sm_backend(ejabberd_sm_mnesia) -> mnesia. + +set_meck() -> + meck:expect(gen_hook, add_handler, fun(_, _, _, _, _) -> ok end), + meck:expect(gen_hook, add_handlers, fun(_) -> ok end), + meck:new(ejabberd_commands, []), + meck:expect(ejabberd_commands, register_commands, fun(_) -> ok end), + meck:expect(ejabberd_commands, unregister_commands, fun(_) -> ok end), + ok.