Skip to content

Commit

Permalink
Extract SASL from C2S
Browse files Browse the repository at this point in the history
  • Loading branch information
NelsonVides committed Aug 21, 2023
1 parent 666ec4b commit 7299e6d
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 81 deletions.
171 changes: 91 additions & 80 deletions src/c2s/mongoose_c2s.erl
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@
-type stream_state() :: stream_start | authenticated.
-type state(State) :: connect
| {wait_for_stream, stream_state()}
| {wait_for_feature_before_auth, cyrsasl:sasl_state(), retries()}
| {wait_for_feature_before_auth, mongoose_acc:t(), retries()}
| {wait_for_feature_after_auth, retries()}
| {wait_for_sasl_response, cyrsasl:sasl_state(), retries()}
| {wait_for_sasl_response, mongoose_acc:t(), retries()}
| wait_for_session_establishment
| session_established
| ?EXT_C2S_STATE(State).
Expand Down Expand Up @@ -114,24 +114,31 @@ handle_event(internal, #xmlstreamerror{name = <<"child element too big">> = Err}
c2s_stream_error(StateData, mongoose_xmpp_errors:policy_violation(StateData#c2s_data.lang, Err));
handle_event(internal, #xmlstreamerror{name = Err}, _, StateData) ->
c2s_stream_error(StateData, mongoose_xmpp_errors:xml_not_well_formed(StateData#c2s_data.lang, Err));
handle_event(internal, #xmlel{name = <<"starttls">>} = El, {wait_for_feature_before_auth, SaslState, Retries}, StateData) ->
handle_event(internal, #xmlel{name = <<"starttls">>} = El, {wait_for_feature_before_auth, SaslAcc, Retries}, StateData) ->
case exml_query:attr(El, <<"xmlns">>) of
?NS_TLS ->
handle_starttls(StateData, El, SaslState, Retries);
handle_starttls(StateData, El, SaslAcc, Retries);
_ ->
c2s_stream_error(StateData, mongoose_xmpp_errors:invalid_namespace())
end;
handle_event(internal, #xmlel{name = <<"auth">>} = El, {wait_for_feature_before_auth, SaslState, Retries}, StateData) ->
handle_event(internal, #xmlel{name = <<"auth">>} = El, {wait_for_feature_before_auth, SaslAcc, Retries}, StateData) ->
case exml_query:attr(El, <<"xmlns">>) of
?NS_SASL ->
handle_auth_start(StateData, El, SaslState, Retries);
handle_auth_start(StateData, El, SaslAcc, Retries);
_ ->
c2s_stream_error(StateData, mongoose_xmpp_errors:invalid_namespace())
end;
handle_event(internal, #xmlel{name = <<"response">>} = El, {wait_for_sasl_response, SaslState, Retries}, StateData) ->
handle_event(internal, #xmlel{name = <<"response">>} = El, {wait_for_sasl_response, SaslAcc, Retries}, StateData) ->
case exml_query:attr(El, <<"xmlns">>) of
?NS_SASL ->
handle_auth_continue(StateData, El, SaslState, Retries);
handle_auth_continue(StateData, El, SaslAcc, Retries);
_ ->
c2s_stream_error(StateData, mongoose_xmpp_errors:invalid_namespace())

Check warning on line 136 in src/c2s/mongoose_c2s.erl

View check run for this annotation

Codecov / codecov/patch

src/c2s/mongoose_c2s.erl#L136

Added line #L136 was not covered by tests
end;
handle_event(internal, #xmlel{name = <<"abort">>} = El, {wait_for_sasl_response, SaslAcc, Retries}, StateData) ->
case exml_query:attr(El, <<"xmlns">>) of

Check warning on line 139 in src/c2s/mongoose_c2s.erl

View check run for this annotation

Codecov / codecov/patch

src/c2s/mongoose_c2s.erl#L139

Added line #L139 was not covered by tests
?NS_SASL ->
handle_sasl_abort(StateData, SaslAcc, Retries);

Check warning on line 141 in src/c2s/mongoose_c2s.erl

View check run for this annotation

Codecov / codecov/patch

src/c2s/mongoose_c2s.erl#L141

Added line #L141 was not covered by tests
_ ->
c2s_stream_error(StateData, mongoose_xmpp_errors:invalid_namespace())
end;
Expand Down Expand Up @@ -354,10 +361,10 @@ get_xml_lang(StreamStart) ->
?MYLANG
end.

-spec handle_starttls(data(), exml:element(), cyrsasl:sasl_state(), retries()) -> fsm_res().
-spec handle_starttls(data(), exml:element(), mongoose_acc:t(), retries()) -> fsm_res().
handle_starttls(StateData = #c2s_data{socket = TcpSocket,
parser = Parser,
listener_opts = LOpts}, El, SaslState, Retries) ->
listener_opts = LOpts}, El, SaslAcc, Retries) ->
send_xml(StateData, mongoose_c2s_stanzas:tls_proceed()), %% send last negotiation chunk via tcp
case mongoose_c2s_socket:tcp_to_tls(TcpSocket, LOpts) of
{ok, TlsSocket} ->
Expand All @@ -371,7 +378,7 @@ handle_starttls(StateData = #c2s_data{socket = TcpSocket,
ErrorStanza = mongoose_xmpp_errors:bad_request(StateData#c2s_data.lang, <<"bad_config">>),
Err = jlib:make_error_reply(El, ErrorStanza),
send_element_from_server_jid(StateData, Err),
maybe_retry_state(StateData, {wait_for_feature_before_auth, SaslState, Retries});
maybe_retry_state(StateData, {wait_for_feature_before_auth, SaslAcc, Retries});

Check warning on line 381 in src/c2s/mongoose_c2s.erl

View check run for this annotation

Codecov / codecov/patch

src/c2s/mongoose_c2s.erl#L381

Added line #L381 was not covered by tests
{error, closed} ->
{stop, {shutdown, tls_closed}};
{error, timeout} ->
Expand All @@ -380,77 +387,72 @@ handle_starttls(StateData = #c2s_data{socket = TcpSocket,
{stop, TlsAlert}
end.

-spec handle_auth_start(data(), exml:element(), cyrsasl:sasl_state(), retries()) -> fsm_res().
handle_auth_start(StateData, El, SaslState, Retries) ->
case {mongoose_c2s_socket:is_ssl(StateData#c2s_data.socket), StateData#c2s_data.listener_opts} of
{false, #{tls := #{mode := starttls_required}}} ->
Error = mongoose_xmpp_errors:policy_violation(
StateData#c2s_data.lang, <<"Use of STARTTLS required">>),
c2s_stream_error(StateData, Error);
_ ->
do_handle_auth_start(StateData, El, SaslState, Retries)
end.

-spec do_handle_auth_start(data(), exml:element(), cyrsasl:sasl_state(), retries()) -> fsm_res().
do_handle_auth_start(StateData, El, SaslState, Retries) ->
-spec handle_auth_start(data(), exml:element(), mongoose_acc:t(), retries()) -> fsm_res().
handle_auth_start(StateData, El, SaslAcc, Retries) ->
Mech = exml_query:attr(El, <<"mechanism">>),
ClientIn = base64:mime_decode(exml_query:cdata(El)),
AuthMech = get_auth_mechs(StateData),
SocketData = #{socket => StateData#c2s_data.socket, auth_mech => AuthMech,
listener_opts => StateData#c2s_data.listener_opts},
StepResult = cyrsasl:server_start(SaslState, Mech, ClientIn, SocketData),
handle_sasl_step(StateData, StepResult, SaslState, Retries).

-spec handle_auth_continue(data(), exml:element(), cyrsasl:sasl_state(), retries()) -> fsm_res().
handle_auth_continue(StateData, El, SaslState, Retries) ->
StepResult = mongoose_c2s_sasl:start(StateData, SaslAcc, Mech, ClientIn),
handle_sasl_step(StateData, StepResult, Retries).

-spec handle_auth_continue(data(), exml:element(), mongoose_acc:t(), retries()) -> fsm_res().
handle_auth_continue(StateData, El, SaslAcc, Retries) ->
ClientIn = base64:mime_decode(exml_query:cdata(El)),
StepResult = cyrsasl:server_step(SaslState, ClientIn),
handle_sasl_step(StateData, StepResult, SaslState, Retries).

-spec handle_sasl_step(data(), cyrsasl:sasl_result(), cyrsasl:sasl_state(), retries()) -> fsm_res().
handle_sasl_step(StateData, {ok, Creds}, _, _) ->
handle_sasl_success(StateData, Creds);
handle_sasl_step(StateData = #c2s_data{listener_opts = LOpts}, {continue, ServerOut, NewSaslState}, _, Retries) ->
Challenge = [#xmlcdata{content = jlib:encode_base64(ServerOut)}],
send_element_from_server_jid(StateData, mongoose_c2s_stanzas:sasl_challenge_stanza(Challenge)),
{next_state, {wait_for_sasl_response, NewSaslState, Retries}, StateData, state_timeout(LOpts)};
handle_sasl_step(#c2s_data{host_type = HostType, lserver = Server} = StateData,
{error, Error, Username}, SaslState, Retries) ->
?LOG_INFO(#{what => auth_failed,
text => <<"Failed SASL authentication">>,
user => Username, lserver => Server, c2s_state => StateData}),
mongoose_hooks:auth_failed(HostType, Server, Username),
send_element_from_server_jid(StateData, mongoose_c2s_stanzas:sasl_failure_stanza(Error)),
maybe_retry_state(StateData, {wait_for_feature_before_auth, SaslState, Retries});
handle_sasl_step(#c2s_data{host_type = HostType, lserver = Server} = StateData,
{error, Error}, SaslState, Retries) ->
mongoose_hooks:auth_failed(HostType, Server, unknown),
send_element_from_server_jid(StateData, mongoose_c2s_stanzas:sasl_failure_stanza(Error)),
maybe_retry_state(StateData, {wait_for_feature_before_auth, SaslState, Retries}).

-spec handle_sasl_success(data(), term()) -> fsm_res().
handle_sasl_success(State = #c2s_data{listener_opts = LOpts}, Creds) ->
ServerOut = mongoose_credentials:get(Creds, sasl_success_response, undefined),
AuthModule = mongoose_credentials:get(Creds, auth_module),
send_element_from_server_jid(State, mongoose_c2s_stanzas:sasl_success_stanza(ServerOut)),
User = mongoose_credentials:get(Creds, username),
NewState = State#c2s_data{streamid = new_stream_id(),
jid = jid:make_bare(User, State#c2s_data.lserver),
auth_module = AuthModule},
?LOG_INFO(#{what => auth_success, text => <<"Accepted SASL authentication">>,
c2s_state => NewState}),
{next_state, {wait_for_stream, authenticated}, NewState, state_timeout(LOpts)}.
StepResult = mongoose_c2s_sasl:continue(StateData, SaslAcc, ClientIn),
handle_sasl_step(StateData, StepResult, Retries).

-spec handle_sasl_step(data(), mongoose_c2s_sasl:result(), retries()) -> fsm_res().
handle_sasl_step(StateData, {success, NewSaslAcc, Result}, _Retries) ->
handle_sasl_success(StateData, NewSaslAcc, Result);
handle_sasl_step(StateData, {continue, NewSaslAcc, Result}, Retries) ->
handle_sasl_continue(StateData, NewSaslAcc, Result, Retries);
handle_sasl_step(StateData, {failure, NewSaslAcc, Result}, Retries) ->
handle_sasl_failure(StateData, NewSaslAcc, Result, Retries);
handle_sasl_step(StateData, {error, NewSaslAcc, Result}, Retries) ->
handle_sasl_error(StateData, NewSaslAcc, Result, Retries).

-spec handle_sasl_success(data(), mongoose_acc:t(), mongoose_c2s_sasl:success()) -> fsm_res().
handle_sasl_success(StateData, SaslAcc,
#{server_out := MaybeServerOut, jid := Jid, auth_module := AuthMod}) ->
StateData1 = StateData#c2s_data{streamid = new_stream_id(), jid = Jid, auth_module = AuthMod},
El = mongoose_c2s_stanzas:sasl_success_stanza(MaybeServerOut),
send_acc_from_server_jid(StateData1, SaslAcc, El),
?LOG_INFO(#{what => auth_success, text => <<"Accepted SASL authentication">>, c2s_state => StateData1}),
{next_state, {wait_for_stream, authenticated}, StateData1, state_timeout(StateData1)}.

-spec handle_sasl_continue(data(), mongoose_acc:t(), mongoose_c2s_sasl:continue(), retries()) -> fsm_res().
handle_sasl_continue(StateData, SaslAcc, #{server_out := ServerOut}, Retries) ->
El = mongoose_c2s_stanzas:sasl_challenge_stanza(ServerOut),
NewSaslAcc = send_acc_from_server_jid(StateData, SaslAcc, El),
{next_state, {wait_for_sasl_response, NewSaslAcc, Retries}, StateData, state_timeout(StateData)}.

-spec handle_sasl_failure(data(), mongoose_acc:t(), mongoose_c2s_sasl:failure(), retries()) -> fsm_res().
handle_sasl_failure(#c2s_data{host_type = HostType, lserver = LServer} = StateData, SaslAcc,
#{server_out := ServerOut, maybe_username := Username}, Retries) ->
?LOG_INFO(#{what => auth_failed, text => <<"Failed SASL authentication">>,
jid => Username, c2s_state => StateData}),
mongoose_hooks:auth_failed(HostType, LServer, Username),
El = mongoose_c2s_stanzas:sasl_failure_stanza(ServerOut),
NewSaslAcc = send_acc_from_server_jid(StateData, SaslAcc, El),
maybe_retry_state(StateData, {wait_for_feature_before_auth, NewSaslAcc, Retries}).

-spec handle_sasl_error(data(), mongoose_acc:t(), mongoose_c2s_sasl:error(), retries()) -> fsm_res().
handle_sasl_error(#c2s_data{lang = Lang} = StateData, _SaslAcc,
#{type := Type, text := Text}, _Retries) ->
Error = mongoose_xmpp_errors:Type(Lang, Text),
c2s_stream_error(StateData, Error).

-spec handle_sasl_abort(data(), mongoose_acc:t(), retries()) -> fsm_res().
handle_sasl_abort(StateData, SaslAcc, Retries) ->
Error = #{server_out => <<"aborted">>, maybe_username => StateData#c2s_data.jid},
handle_sasl_failure(StateData, SaslAcc, Error, Retries).

Check warning on line 447 in src/c2s/mongoose_c2s.erl

View check run for this annotation

Codecov / codecov/patch

src/c2s/mongoose_c2s.erl#L446-L447

Added lines #L446 - L447 were not covered by tests

-spec stream_start_features_before_auth(data()) -> fsm_res().
stream_start_features_before_auth(#c2s_data{host_type = HostType, lserver = LServer,
listener_opts = LOpts} = StateData) ->
stream_start_features_before_auth(#c2s_data{listener_opts = LOpts} = StateData) ->
send_header(StateData),
CredOpts = mongoose_credentials:make_opts(LOpts),
Creds = mongoose_credentials:new(LServer, HostType, CredOpts),
SASLState = cyrsasl:server_new(<<"jabber">>, LServer, HostType, <<>>, [], Creds),
SaslAcc = mongoose_c2s_sasl:new(StateData),
StreamFeatures = mongoose_c2s_stanzas:stream_features_before_auth(StateData),
send_element_from_server_jid(StateData, StreamFeatures),
{next_state, {wait_for_feature_before_auth, SASLState, ?AUTH_RETRIES}, StateData, state_timeout(LOpts)}.
SaslAcc1 = send_acc_from_server_jid(StateData, SaslAcc, StreamFeatures),
{next_state, {wait_for_feature_before_auth, SaslAcc1, ?AUTH_RETRIES}, StateData, state_timeout(LOpts)}.

-spec stream_start_features_after_auth(data()) -> fsm_res().
stream_start_features_after_auth(#c2s_data{listener_opts = LOpts} = StateData) ->
Expand Down Expand Up @@ -640,10 +642,10 @@ maybe_retry_state({wait_for_sasl_response, _, 0}) ->
{stop, {shutdown, retries}};
maybe_retry_state({wait_for_feature_after_auth, Retries}) ->
{wait_for_feature_after_auth, Retries - 1};
maybe_retry_state({wait_for_feature_before_auth, SaslState, Retries}) ->
{wait_for_feature_before_auth, SaslState, Retries - 1};
maybe_retry_state({wait_for_sasl_response, SaslState, Retries}) ->
{wait_for_sasl_response, SaslState, Retries - 1};
maybe_retry_state({wait_for_feature_before_auth, SaslAcc, Retries}) ->
{wait_for_feature_before_auth, SaslAcc, Retries - 1};
maybe_retry_state({wait_for_sasl_response, SaslAcc, Retries}) ->
{wait_for_sasl_response, SaslAcc, Retries - 1};

Check warning on line 648 in src/c2s/mongoose_c2s.erl

View check run for this annotation

Codecov / codecov/patch

src/c2s/mongoose_c2s.erl#L648

Added line #L648 was not covered by tests
maybe_retry_state(?EXT_C2S_STATE(_) = State) ->
State.

Expand Down Expand Up @@ -944,6 +946,13 @@ send_element_from_server_jid(StateData, El) ->
element => El}),
do_send_element(StateData, Acc, El).

-spec send_acc_from_server_jid(data(), mongoose_acc:t(), exml:element()) -> mongoose_acc:t().
send_acc_from_server_jid(StateData = #c2s_data{lserver = LServer, jid = Jid}, Acc0, El) ->
ServerJid = jid:make_noprep(<<>>, LServer, <<>>),
ParamsAcc = #{from_jid => ServerJid, to_jid => Jid, element => El},
Acc1 = mongoose_acc:update_stanza(ParamsAcc, Acc0),
do_send_element(StateData, Acc1, El).

-spec maybe_send_xml(data(), mongoose_acc:t(), exml:element()) -> mongoose_acc:t().
maybe_send_xml(StateData = #c2s_data{host_type = HostType, lserver = LServer}, undefined, ToSend) ->
Acc = mongoose_acc:new(#{host_type => HostType, lserver => LServer, location => ?LOCATION}),
Expand All @@ -969,6 +978,8 @@ send_xml(#c2s_data{socket = Socket}, XmlElements) when is_list(XmlElements) ->
mongoose_c2s_socket:send_xml(Socket, XmlElements).


state_timeout(#c2s_data{listener_opts = LOpts}) ->
state_timeout(LOpts);
state_timeout(#{c2s_state_timeout := Timeout}) ->
{state_timeout, Timeout, state_timeout_termination}.

Expand Down Expand Up @@ -1040,7 +1051,7 @@ cast(Pid, EventTag, EventContent) ->
create_data(#{host_type := HostType, jid := Jid}) ->
#c2s_data{host_type = HostType, jid = Jid}.

-spec get_auth_mechs(data()) -> [cyrsasl:mechanism()].
-spec get_auth_mechs(data()) -> [mongoose_c2s_sasl:mechanism()].
get_auth_mechs(#c2s_data{host_type = HostType} = StateData) ->
[M || M <- cyrsasl:listmech(HostType), filter_mechanism(StateData, M)].

Expand Down
Loading

0 comments on commit 7299e6d

Please sign in to comment.