diff --git a/lib/phoenix/channel/transport.ex b/lib/phoenix/channel/transport.ex index 6479db5d2f..6da9cb2dfd 100644 --- a/lib/phoenix/channel/transport.ex +++ b/lib/phoenix/channel/transport.ex @@ -250,31 +250,34 @@ defmodule Phoenix.Channel.Transport do Otherwise a otherwise a 403 Forbidden response will be sent and the connection halted. It is a noop if the connection has been halted. """ - def check_origin(conn, allowed_origins, sender \\ &Plug.Conn.send_resp/1) + def check_origin(conn, endpoint, check_origin, sender \\ &Plug.Conn.send_resp/1) - def check_origin(%Plug.Conn{halted: true} = conn, _allowed_origins, _sender) do - conn - end + def check_origin(%Plug.Conn{halted: true} = conn, _endpoint, _check_origin, _sender), + do: conn - def check_origin(conn, allowed_origins, sender) do + def check_origin(conn, endpoint, check_origin, sender) do import Plug.Conn origin = get_req_header(conn, "origin") |> List.first - if origin_allowed?(origin, allowed_origins) do - conn - else - resp(conn, :forbidden, "") - |> sender.() - |> halt() + cond do + is_nil(origin) -> + conn + origin_allowed?(check_origin, origin, endpoint) -> + conn + true -> + resp(conn, :forbidden, "") + |> sender.() + |> halt() end end - defp origin_allowed?(nil, _) do - true - end - defp origin_allowed?(_, nil) do - true - end + defp origin_allowed?(false, _, _), + do: true + defp origin_allowed?(true, origin, endpoint), + do: compare?(URI.parse(origin).host, endpoint.config(:url)[:host]) + defp origin_allowed?(check_origin, origin, _endpoint) when is_list(check_origin), + do: origin_allowed?(origin, check_origin) + defp origin_allowed?(origin, allowed_origins) do origin = URI.parse(origin) diff --git a/lib/phoenix/transports/long_poll.ex b/lib/phoenix/transports/long_poll.ex index b9fc9c13fd..cc70f32adc 100644 --- a/lib/phoenix/transports/long_poll.ex +++ b/lib/phoenix/transports/long_poll.ex @@ -1,47 +1,47 @@ defmodule Phoenix.Transports.LongPoll do @moduledoc """ - Handles LongPoll clients for the Channel Transport layer. + Socket transport for long poll clients. ## Configuration - The long poller is configurable in your Socket's transport configuration: + The long poll is configurable in your socket: transport :longpoll, Phoenix.Transports.LongPoll, window_ms: 10_000, pubsub_timeout_ms: 1000, - crypto: [iterations: 1000, - length: 32, - digest: :sha256, - cache: Plug.Keys], + log: false, + check_origin: true, + crypto: [max_age: 1209600] * `:window_ms` - how long the client can wait for new messages - in it's poll request. + in its poll request + * `:pubsub_timeout_ms` - how long a request can wait for the - pubsub layer to respond. - * `:crypto` - configuration for the key generated to sign the - private topic used for the long poller session (see `Plug.Crypto.KeyGenerator`). + pubsub layer to respond + + * `:crypto` - options for verifying and signing the token, accepted + by `Phoenix.Token`. By default tokens are valid for 2 weeks + + * `:log` - if the transport layer itself should log and, if so, the level + + * `:check_origin` - if we should check the origin of requests when the + origin header is present. It defaults to true and, in such cases, + it will check against the host value in `YourApp.Endpoint.config(:url)[:host]`. + It may be set to `false` (not recommended) or to a list of explicitly + allowed origins """ ## Transport callbacks @behaviour Phoenix.Channel.Transport - @doc """ - Provides the deault transport configuration to sockets. - - * `:serializer` - The `Phoenix.Socket.Message` serializer - * `:pubsub_timeout_ms` - The timeout to wait for the LongPoll.Server ack - * `:log` - The log level, for example `:info`. Disabled by default - * `:timeout` - The connection timeout in milliseconds, defaults to `:infinity` - * `:crypto` - The list of encryption options for the `Plug.Session` - """ def default_config() do [window_ms: 10_000, pubsub_timeout_ms: 1000, serializer: Phoenix.Transports.LongPollSerializer, log: false, - crypto: [iterations: 1000, length: 32, - digest: :sha256, cache: Plug.Keys]] + check_origin: true, + crypto: [max_age: 1209600]] end def handler_for(:cowboy), do: Plug.Adapters.Cowboy.Handler @@ -58,10 +58,12 @@ defmodule Phoenix.Transports.LongPoll do alias Phoenix.Transports.LongPoll alias Phoenix.Channel.Transport + @doc false def init(opts) do opts end + @doc false def call(conn, {endpoint, handler, transport}) do {_, opts} = handler.__transport__(transport) @@ -70,7 +72,7 @@ defmodule Phoenix.Transports.LongPoll do |> Plug.Conn.fetch_query_params |> Transport.transport_log(opts[:log]) |> Transport.force_ssl(handler, endpoint) - |> Transport.check_origin(opts[:origins], &status_json(&1, %{})) + |> Transport.check_origin(endpoint, opts[:check_origin], &status_json(&1, %{})) |> dispatch(endpoint, handler, transport, opts) end @@ -167,14 +169,14 @@ defmodule Phoenix.Transports.LongPoll do child = [socket, opts[:window_ms], priv_topic] {:ok, server_pid} = Supervisor.start_child(LongPoll.Supervisor, child) - {priv_topic, sign_token(endpoint, priv_topic), server_pid} + {priv_topic, sign_token(endpoint, priv_topic, opts), server_pid} end # Retrieves the serialized `Phoenix.LongPoll.Server` pid # by publishing a message in the encrypted private topic. @doc false def resume_session(%{"token" => token}, endpoint, opts) do - case verify_token(endpoint, token) do + case verify_token(endpoint, token, opts) do {:ok, priv_topic} -> ref = :erlang.make_ref() :ok = subscribe(endpoint, priv_topic) @@ -225,12 +227,12 @@ defmodule Phoenix.Transports.LongPoll do Phoenix.PubSub.broadcast_from(endpoint.__pubsub_server__, self, priv_topic, msg) end - defp sign_token(endpoint, priv_topic) do - Phoenix.Token.sign(endpoint, Atom.to_string(endpoint.__pubsub_server__), priv_topic) + defp sign_token(endpoint, priv_topic, opts) do + Phoenix.Token.sign(endpoint, Atom.to_string(endpoint.__pubsub_server__), priv_topic, opts[:crypto]) end - defp verify_token(endpoint, signed) do - Phoenix.Token.verify(endpoint, Atom.to_string(endpoint.__pubsub_server__), signed) + defp verify_token(endpoint, signed, opts) do + Phoenix.Token.verify(endpoint, Atom.to_string(endpoint.__pubsub_server__), signed, opts[:crypto]) end defp status_json(conn, data) do diff --git a/lib/phoenix/transports/long_poll/server.ex b/lib/phoenix/transports/long_poll/server.ex index 6ccfca3d7a..ff829ba47b 100644 --- a/lib/phoenix/transports/long_poll/server.ex +++ b/lib/phoenix/transports/long_poll/server.ex @@ -15,10 +15,10 @@ defmodule Phoenix.Transports.LongPoll.Supervisor do end defmodule Phoenix.Transports.LongPoll.Server do - use GenServer - @moduledoc false + use GenServer + alias Phoenix.Channel.Transport alias Phoenix.PubSub alias Phoenix.Socket.Broadcast @@ -38,14 +38,15 @@ defmodule Phoenix.Transports.LongPoll.Server do GenServer.start_link(__MODULE__, [socket, window_ms, priv_topic]) end - @doc false + ## Callbacks + def init([socket, window_ms, priv_topic]) do Process.flag(:trap_exit, true) state = %{buffer: [], socket: socket, - sockets: HashDict.new, - sockets_inverse: HashDict.new, + channels: HashDict.new, + channels_inverse: HashDict.new, window_ms: trunc(window_ms * 1.5), pubsub_server: socket.endpoint.__pubsub_server__(), priv_topic: priv_topic, @@ -59,23 +60,18 @@ defmodule Phoenix.Transports.LongPoll.Server do {:ok, state} end - @doc """ - Stops the server - """ def handle_call(:stop, _from, state), do: {:stop, :shutdown, :ok, state} - @doc """ - Dispatches client message back through Transport layer. - """ + # Handle client dispatches def handle_info({:dispatch, msg, ref}, state) do msg - |> Transport.dispatch(state.sockets, self, state.socket) + |> Transport.dispatch(state.channels, self, state.socket) |> case do - {:ok, socket_pid, reply_msg} -> + {:ok, channel_pid, reply_msg} -> :ok = broadcast_from(state, {:ok, :dispatch, ref}) - new_state = %{state | sockets: HashDict.put(state.sockets, msg.topic, socket_pid), - sockets_inverse: HashDict.put(state.sockets_inverse, socket_pid, msg.topic)} + new_state = %{state | channels: HashDict.put(state.channels, msg.topic, channel_pid), + channels_inverse: HashDict.put(state.channels_inverse, channel_pid, msg.topic)} publish_reply(reply_msg, new_state) {:ok, reply_msg} -> @@ -92,39 +88,31 @@ defmodule Phoenix.Transports.LongPoll.Server do end end - @doc """ - Forwards replied/broadcasted message from Channels back to client. - """ + # Forwards replied/broadcasted message from Channels back to client. def handle_info(%Message{} = msg, state) do publish_encoded_reply(msg, state) end - @doc """ - Detects disconnect broadcasts and shuts down - """ + # Detects disconnect broadcasts and shuts down def handle_info(%Broadcast{event: "disconnect"}, state) do {:stop, {:shutdown, :disconnected}, state} end - @doc """ - Crash if pubsub adapter goes down - """ + # Crash if pubsub adapter goes down def handle_info({:EXIT, pub_pid, :shutdown}, %{pubsub_server: pub_pid} = state) do - {:stop, :pubsub_server_terminated, state} + {:stop, {:shutdown,:pubsub_server_terminated}, state} end - @doc """ - Trap channel process exits and notify client of close or error events - - `:normal` exits and shutdowns indicate the channel shutdown gracefully from - return. Any other exit reason is treated as an error. - """ - def handle_info({:EXIT, socket_pid, reason}, state) do - case HashDict.get(state.sockets_inverse, socket_pid) do + # Trap channel process exits and notify client of close or error events + # + # Normal exits and shutdowns indicate the channel shutdown gracefully + # from return. Any other exit reason is treated as an error. + def handle_info({:EXIT, channel_pid, reason}, state) do + case HashDict.get(state.channels_inverse, channel_pid) do nil -> {:noreply, state} topic -> - new_state = %{state | sockets: HashDict.delete(state.sockets, topic), - sockets_inverse: HashDict.delete(state.sockets_inverse, socket_pid)} + new_state = %{state | channels: HashDict.delete(state.channels, topic), + channels_inverse: HashDict.delete(state.channels_inverse, channel_pid)} case reason do :normal -> publish_reply(Transport.chan_close_message(topic), new_state) @@ -138,7 +126,6 @@ defmodule Phoenix.Transports.LongPoll.Server do def handle_info({:subscribe, ref}, state) do :ok = broadcast_from(state, {:ok, :subscribe, ref}) - {:noreply, state} end diff --git a/lib/phoenix/transports/websocket.ex b/lib/phoenix/transports/websocket.ex index 79a35c4edf..92cb540779 100644 --- a/lib/phoenix/transports/websocket.ex +++ b/lib/phoenix/transports/websocket.ex @@ -1,28 +1,59 @@ defmodule Phoenix.Transports.WebSocket do @moduledoc """ - Handles WebSocket clients for the Channel Transport layer. + Socket transport for websocket clients. ## Configuration - By default, JSON encoding is used to broker messages to and from clients and - Websockets, by default, do not timeout if the connection is lost. The - maximum timeout duration and serializer can be configured in your Socket's - transport configuration: + The websocket is configurable in your socket: transport :websocket, Phoenix.Transports.WebSocket, - serializer: MySerializer - timeout: 60000 + timeout: :infinity, + serializer: Phoenix.Transports.WebSocketSerializer, + log: false, + check_origin: true - The `serializer` module needs only to implement the `encode!/1` and - `decode!/2` functions defined by the `Phoenix.Transports.Serializer` behaviour. + * `:timeout` - the timeout for keeping websocket connections + open after it last received data + + * `:log` - if the transport layer itself should log and, if so, the level + + * `:serializer` - the serializer for websocket messages + + * `:check_origin` - if we should check the origin of requests when the + origin header is present. It defaults to true and, in such cases, + it will check against the host value in `YourApp.Endpoint.config(:url)[:host]`. + It may be set to `false` (not recommended) or to a list of explicitly + allowed origins + + ## Serializer + + By default, JSON encoding is used to broker messages to and from clients. + A custom serializer may be given as module which implements the `encode!/1` + and `decode!/2` functions defined by the `Phoenix.Transports.Serializer` + behaviour. + + The `encode!/1` function must return a tuple in the format + `{:socket_push, :text | :binary, String.t | binary}`. """ @behaviour Phoenix.Channel.Transport + def default_config() do + [serializer: Phoenix.Transports.WebSocketSerializer, + timeout: :infinity, + log: false, + check_origin: true] + end + + def handler_for(:cowboy), do: Phoenix.Endpoint.CowboyWebSocket + + ## Callbacks + import Plug.Conn, only: [fetch_query_params: 1, send_resp: 3] alias Phoenix.Socket.Broadcast alias Phoenix.Channel.Transport + @doc false def init(%Plug.Conn{method: "GET"} = conn, {endpoint, handler, transport}) do {_, opts} = handler.__transport__(transport) @@ -31,7 +62,7 @@ defmodule Phoenix.Transports.WebSocket do |> Plug.Conn.fetch_query_params |> Transport.transport_log(opts[:log]) |> Transport.force_ssl(handler, endpoint) - |> Transport.check_origin(opts[:origins]) + |> Transport.check_origin(endpoint, opts[:check_origin]) case conn do %{halted: false} = conn -> @@ -55,24 +86,7 @@ defmodule Phoenix.Transports.WebSocket do {:error, conn} end - @doc """ - Provides the deault transport configuration to sockets. - - * `:serializer` - The `Phoenix.Socket.Message` serializer - * `:log` - The log level, for example `:info`. Disabled by default - * `:timeout` - The connection timeout in milliseconds, defaults to `:infinity` - """ - def default_config() do - [serializer: Phoenix.Transports.WebSocketSerializer, - timeout: :infinity, - log: false] - end - - def handler_for(:cowboy), do: Phoenix.Endpoint.CowboyWebSocket - - @doc """ - Handles initalization of the websocket. - """ + @doc false def ws_init({socket, config}) do Process.flag(:trap_exit, true) serializer = Keyword.fetch!(config, :serializer) @@ -81,21 +95,18 @@ defmodule Phoenix.Transports.WebSocket do if socket.id, do: socket.endpoint.subscribe(self, socket.id, link: true) {:ok, %{socket: socket, - sockets: HashDict.new, - sockets_inverse: HashDict.new, + channels: HashDict.new, + channels_inverse: HashDict.new, serializer: serializer}, timeout} end - @doc """ - Receives JSON encoded `%Phoenix.Socket.Message{}` from client and dispatches - to Transport layer. - """ + @doc false def ws_handle(opcode, payload, state) do msg = state.serializer.decode!(payload, opcode: opcode) - case Transport.dispatch(msg, state.sockets, self, state.socket) do - {:ok, socket_pid, reply_msg} -> - format_reply(state.serializer.encode!(reply_msg), put(state, msg.topic, socket_pid)) + case Transport.dispatch(msg, state.channels, self, state.socket) do + {:ok, channel_pid, reply_msg} -> + format_reply(state.serializer.encode!(reply_msg), put(state, msg.topic, channel_pid)) {:ok, reply_msg} -> format_reply(state.serializer.encode!(reply_msg), state) :ok -> @@ -106,11 +117,12 @@ defmodule Phoenix.Transports.WebSocket do end end - def ws_info({:EXIT, socket_pid, reason}, state) do - case HashDict.get(state.sockets_inverse, socket_pid) do + @doc false + def ws_info({:EXIT, channel_pid, reason}, state) do + case HashDict.get(state.channels_inverse, channel_pid) do nil -> {:ok, state} topic -> - new_state = delete(state, topic, socket_pid) + new_state = delete(state, topic, channel_pid) case reason do :normal -> @@ -125,14 +137,12 @@ defmodule Phoenix.Transports.WebSocket do end end - @doc """ - Detects disconnect broadcasts and shuts down - """ + @doc false def ws_info(%Broadcast{event: "disconnect"}, state) do {:shutdown, state} end - def ws_info({:socket_push, :text, _encoded_payload} = msg, state) do + def ws_info({:socket_push, _, _encoded_payload} = msg, state) do format_reply(msg, state) end @@ -140,27 +150,29 @@ defmodule Phoenix.Transports.WebSocket do {:ok, state} end + @doc false def ws_terminate(_reason, _state) do :ok end + @doc false def ws_close(state) do - for {pid, _} <- state.sockets_inverse do + for {pid, _} <- state.channels_inverse do Phoenix.Channel.Server.close(pid) end end - defp put(state, topic, socket_pid) do - %{state | sockets: HashDict.put(state.sockets, topic, socket_pid), - sockets_inverse: HashDict.put(state.sockets_inverse, socket_pid, topic)} + defp put(state, topic, channel_pid) do + %{state | channels: HashDict.put(state.channels, topic, channel_pid), + channels_inverse: HashDict.put(state.channels_inverse, channel_pid, topic)} end - defp delete(state, topic, socket_pid) do - %{state | sockets: HashDict.delete(state.sockets, topic), - sockets_inverse: HashDict.delete(state.sockets_inverse, socket_pid)} + defp delete(state, topic, channel_pid) do + %{state | channels: HashDict.delete(state.channels, topic), + channels_inverse: HashDict.delete(state.channels_inverse, channel_pid)} end - defp format_reply({:socket_push, :text, encoded_payload}, state) do - {:reply, {:text, encoded_payload}, state} + defp format_reply({:socket_push, encoding, encoded_payload}, state) do + {:reply, {encoding, encoded_payload}, state} end end diff --git a/test/phoenix/channel/transport_test.exs b/test/phoenix/channel/transport_test.exs index a5c29c9818..b7d48d2f8e 100644 --- a/test/phoenix/channel/transport_test.exs +++ b/test/phoenix/channel/transport_test.exs @@ -1,44 +1,45 @@ -# TODO: We need simpler unit tests that -# do pass through the whole endpoint defmodule Phoenix.Channel.TransportTest do use ExUnit.Case, async: true use RouterHelper - alias __MODULE__.Endpoint - - defmodule Endpoint do - use Phoenix.Endpoint, otp_app: :transport_app - plug :check_origin - plug :render - defp check_origin(conn, _) do - allowed_origins = ["//example.com", "http://scheme.com", "//port.com:81"] - Phoenix.Channel.Transport.check_origin(conn, allowed_origins) - end - defp render(conn, _), - do: send_resp(conn, 200, "ok") + alias Phoenix.Channel.Transport + + def config(:url) do + [host: "host.com"] end - setup_all do - Endpoint.start_link() - :ok + defp check_origin(origin, origins) do + conn = conn(:get, "/") |> put_req_header("origin", origin) + Transport.check_origin(conn, __MODULE__, origins) end - defp call(origin) do - conn(:get, "/") - |> put_req_header("origin", origin) - |> Endpoint.call([]) + test "does not check origin if disabled" do + refute check_origin("/", false).halted end - test "does not check origin if none is given" do - conn = conn(:get, "/") |> Endpoint.call([]) - assert conn.status == 200 + test "checks origin against host" do + refute check_origin("https://host.com/", true).halted + conn = check_origin("https://another.com/", true) + assert conn.halted + assert conn.status == 403 end - test "check the origin of requests against allowed origins" do - assert call("https://example.com").status == 200 - assert call("http://port.com:81").status == 200 - assert call("http://notallowed.com").status == 403 - assert call("https://scheme.com").status == 403 - assert call("http://port.com:82").status == 403 + test "checks the origin of requests against allowed origins" do + origins = ["//example.com", "http://scheme.com", "//port.com:81"] + + refute check_origin("https://example.com/", origins).halted + refute check_origin("http://port.com:81/", origins).halted + + conn = check_origin("http://notallowed.com/", origins) + assert conn.halted + assert conn.status == 403 + + conn = check_origin("https://scheme.com/", origins) + assert conn.halted + assert conn.status == 403 + + conn = check_origin("http://port.com:82/", origins) + assert conn.halted + assert conn.status == 403 end end diff --git a/test/phoenix/integration/channel_test.exs b/test/phoenix/integration/channel_test.exs index ad6eec7f69..c0ed81ad15 100644 --- a/test/phoenix/integration/channel_test.exs +++ b/test/phoenix/integration/channel_test.exs @@ -73,10 +73,10 @@ defmodule Phoenix.Integration.ChannelTest do transport :longpoll, Phoenix.Transports.LongPoll, window_ms: 200, - origins: ["//example.com"] + check_origin: ["//example.com"] transport :websocket, Phoenix.Transports.WebSocket, - origins: ["//example.com"] + check_origin: ["//example.com"] def connect(%{"reject" => "true"}, _socket) do :error diff --git a/test/phoenix/socket_test.exs b/test/phoenix/socket_test.exs index 5cb4d95c38..142bd2f834 100644 --- a/test/phoenix/socket_test.exs +++ b/test/phoenix/socket_test.exs @@ -60,22 +60,18 @@ defmodule Phoenix.SocketTest do end test "__transports__" do - assert UserSocket.__transports__() == %{ - longpoll: {Phoenix.Transports.LongPoll, - [window_ms: 10000, pubsub_timeout_ms: 1000, serializer: Phoenix.Transports.LongPollSerializer, - log: false, crypto: [iterations: 1000, length: 32, digest: :sha256, cache: Plug.Keys]]}, - websocket: {Phoenix.Transports.WebSocket, - [timeout: 1234, serializer: Phoenix.Transports.WebSocketSerializer, log: false]} - } + assert %{longpoll: {Phoenix.Transports.LongPoll, _}, + websocket: {Phoenix.Transports.WebSocket, _}} = UserSocket.__transports__() end test "transport config is exposted and merged with prior registrations" do ws = {Phoenix.Transports.WebSocket, - [timeout: 1234, serializer: Phoenix.Transports.WebSocketSerializer, log: false]} + [timeout: 1234, serializer: Phoenix.Transports.WebSocketSerializer, + log: false, check_origin: true]} lp = {Phoenix.Transports.LongPoll, [window_ms: 10000, pubsub_timeout_ms: 1000, serializer: Phoenix.Transports.LongPollSerializer, - log: false, crypto: [iterations: 1000, length: 32, digest: :sha256, cache: Plug.Keys]]} + log: false, check_origin: true, crypto: [max_age: 1209600]]} assert UserSocket.__transport__(:websocket) == ws assert UserSocket.__transport__(:longpoll) == lp