diff --git a/src/fd_ws.erl b/src/fd_ws.erl index 175eac2..8094de2 100644 --- a/src/fd_ws.erl +++ b/src/fd_ws.erl @@ -4,20 +4,58 @@ -module(fd_ws). -export_type([ + opcode/0, + frame/0, + ws_msg/0 ]). -export([ handshake/1, - response_token/1, recv/2, send/2, - pong/1, + pong/1, pong/2 ]). -include("http.hrl"). -type request() :: #request{}. -type response() :: #response{}. +-type tcp_error() :: closed + | {timeout, RestData :: binary() | erlang:iovec()} + | inet:posix(). + +-define(MAX_PAYLOAD_SIZE, (1 bsl 63)). + +%% Frames +%% https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + +-type opcode() :: continuation + | text + | binary + | close + | ping + | pong. + +-record(frame, + {fin = none :: none | boolean(), + rsv = none :: none | <<_:3>>, + opcode = none :: none | opcode(), + mask = none :: none | boolean(), + payload_length = none :: none | non_neg_integer(), + masking_key = none :: none | <<>> | <<_:32>>, + payload = none :: none | binary()}). + +-type frame() :: #frame{}. + + +%% porcelain messages + +-type ws_msg() :: {text, Payload :: iodata()} + | {binary, Payload :: iodata()} + | {close, Payload :: iodata()} + | {ping, Payload :: iodata()} + | {pong, Payload :: iodata()}. + -spec handshake(Req) -> Result @@ -49,7 +87,23 @@ % the retarded web date, rendering the response, sending it over the socket, % etc. % -% ClientExtensions only joins the <<"sec-websocket-extensions">> fields with ", " +% The returned ClientExtensions is the result of joining the +% <<"sec-websocket-extensions">> fields with ", " +% +% quoth section 9.1: https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 +% +% > Note that like other HTTP header fields, this header field MAY be +% > split or combined across multiple lines. Ergo, the following are +% > equivalent: +% > +% > Sec-WebSocket-Extensions: foo +% > Sec-WebSocket-Extensions: bar; baz=2 +% > +% > is exactly equivalent to +% > +% > Sec-WebSocket-Extensions: foo, bar; baz=2 +% +% Nobody actually uses extensions, so how you choose to parse this is on you. handshake(R = #request{method = get, headers = Hs}) -> %% downcase the headers because have to match on them @@ -145,15 +199,6 @@ unfuck_protocol_string([], PartsRev) -> when Headers :: [{Key, Val}], Key :: binary(), Val :: binary(). - -client_extensions(DowncaseHeaders) -> - unfuck_extensions_string(DowncaseHeaders, []). - - --spec unfuck_extensions_string(KVPairs, Acc) -> Unfucked - when KVPairs :: [{Key :: binary(), Val :: binary()}], - Acc :: Unfucked, - Unfucked :: binary(). % @private % quoth section 9.1: https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 % @@ -172,6 +217,11 @@ client_extensions(DowncaseHeaders) -> % matches <<"sec-websocket-extensions">>, then csv its value to the thing % @end + +client_extensions(DowncaseHeaders) -> + unfuck_extensions_string(DowncaseHeaders, []). + + unfuck_extensions_string([{<<"sec-websocket-extensions">>, Part} | Rest], Acc) -> unfuck_extensions_string(Rest, [Part | Acc]); unfuck_extensions_string([_ | Rest], Acc) -> @@ -286,37 +336,161 @@ response_token(ChallengeToken) when is_binary(ChallengeToken) -> --type opcode() :: continuation - | text - | binary - | close - | ping - | pong. - --record(frame, - {fin = none :: none | boolean(), - rsv = none :: none | <<_:3>>, - opcode = none :: none | opcode(), - mask = none :: none | boolean(), - payload_length = none :: none | non_neg_integer(), - masking_key = none :: none | <<_:256>>, - payload = none :: none | binary()}). - --type frame() :: #frame{}. - - - -spec recv(Socket, Received) -> Result - when Socket :: gen_tcp:socket(), - Received :: binary(), - Result :: {ok, binary()} - | {error, Reason} - Reason :: any(). + when Socket :: gen_tcp:socket(), + Received :: binary(), + Result :: {ok, Message, Frames, Remainder} + | {error, Reason}, + Message :: ws_msg(), + Frames :: [frame()], + Remainder :: binary(), + Reason :: any(). +% @doc +% Equivalent to recv(Socket, Received, []) recv(Sock, Recv) -> - recv(#frame{}, Sock, Recv). + recv(Sock, Recv, []). + +-spec recv(Socket, Received, Frames) -> Result + when Socket :: gen_tcp:socket(), + Received :: binary(), + Frames :: [frame()], + Result :: {ok, Message, NewFrames, Remainder} + | {error, Reason}, + Message :: ws_msg(), + NewFrames :: Frames, + Remainder :: binary(), + Reason :: any(). +% @doc +% Equivalent to recv(Socket, Received, []) + +recv(Sock, Received, Frames) -> + case maybe_pop_msg(Frames) of + {ok, Message, NewFrames} -> + {ok, Message, NewFrames, Received}; + incomplete -> + case recv_frame(#frame{}, Sock, Received) of + {ok, Frame, NewReceived} -> + NewFrames = [Frame | Frames], + recv(Sock, NewReceived, NewFrames); + Error -> + Error + end; + Error -> + Error + end. + + +-spec maybe_pop_msg(Frames) -> Result + when Frames :: [frame()], + Result :: {ok, Message, NewFrames} + | incomplete + | {error, Reason}, + Message :: ws_msg(), + NewFrames :: Frames, + Reason :: any(). +% @doc +% try to parse the stack of frames into a single message +% +% ignores RSV bits +% @end + +maybe_pop_msg([]) -> + incomplete; +% case 1: control frames +maybe_pop_msg([Frame = #frame{opcode = Opcode} | Frames]) + when Opcode =:= close; Opcode =:= ping; Opcode =:= pong -> + case maybe_control_msg(Frame) of + {ok, Msg} -> {ok, Msg, Frames}; + Error -> Error + end; +maybe_pop_msg(_) -> + error(nyi). + + + +-spec maybe_control_msg(Frame) -> Result + when Frame :: frame(), + Result :: {ok, Message} + | {error, Reason}, + Message :: ws_msg(), + Reason :: any(). +% @private +% assume the frame is a control frame, validate it, and unmask the payload + +maybe_control_msg(F = #frame{fin = true, + opcode = Opcode, + mask = Mask, + payload_length = Len, + masking_key = Key, + payload = Payload}) + when ((Opcode =:= close) orelse (Opcode =:= ping) orelse (Opcode =:= pong)) + andalso (Len =< 125) -> + case maybe_unmask(F, Mask, Key, Payload) of + {ok, UnmaskedPayload} -> + Msg = {Opcode, UnmaskedPayload}, + {ok, Msg}; + Error -> + Error + end; +maybe_control_msg(F) -> + {error, {illegal_frame, F}}. + + + +-spec maybe_unmask(Frame, Mask, Key, Payload) -> Result + when Frame :: frame(), + Mask :: boolean(), + Key :: <<>> | <<_:32>>, + Payload :: binary(), + Result :: {ok, Unmasked} + | {error, Reason}, + Unmasked :: binary(), + Reason :: any(). +% @private +% unmask the payload +% @end + +% eliminate invalid pairs of {mask, masking_key} +maybe_unmask(_, true, <>, Payload) -> {ok, mask_unmask(Key, Payload)}; +maybe_unmask(_, false, <<>>, Payload) -> {ok, Payload}; +maybe_unmask(F, true, <<>>, _) -> {error, {illegal_frame, F}}; +maybe_unmask(F, false, <<_:4/bytes>>, _) -> {error, {illegal_frame, F}}. + + +%% invertible +%% see: https://datatracker.ietf.org/doc/html/rfc6455#section-5.3 +mask_unmask(Key = <<_:4/bytes>>, Payload) -> + mu(Key, Key, Payload, <<>>). + +% essentially this is a modular zipWith xor of the masking key with the payload +mu(Key, <>, <>, Acc) -> + NewByte = KeyByte bxor PayloadByte, + NewAcc = <>, + mu(Key, KeyRest, PayloadRest, NewAcc); +% this is the case where we need to refresh the active key +mu(Key, <<>>, Payload, Acc) -> + mu(Key, Key, Payload, Acc); +% done +mu(_, _, <<>>, Acc) -> + Acc. + + + +-spec recv_frame(Parsed, Socket, Received) -> Result + when Parsed :: frame(), + Socket :: gen_tcp:socket(), + Received :: bitstring(), + Result :: {ok, frame(), Remainder} + | {error, Reason}, + Remainder :: bitstring(), + Reason :: any(). +% @private +% parse a single frame off the socket +% @end + %% frame: 1 bit recv_frame(Frame = #frame{fin = none}, Sock, <>) -> NewFin = @@ -326,41 +500,107 @@ recv_frame(Frame = #frame{fin = none}, Sock, <>) -> end, NewFrame = Frame#frame{fin = NewFin}, recv_frame(NewFrame, Sock, Rest); -recv_frame(Frame = #frame{fin = none}, Sock, <<>>) -> - case inet:setopts(Sock, [{active, once}]) of - ok -> - receive - {tcp, Sock, Bin} -> recv_frame(Frame, Sock, Bin); - {tcp_closed, Socket} -> {error, tcp_closed}; - {tcp_error, Socket, Reason} -> {error, {tcp_error, Reason}} - after 3000 -> - {error, timeout} - end; - {error, Reason} -> - {error, {inet, Reason}} - end; +recv_frame(Frame = #frame{fin = none}, Sock, Received = <<>>) -> + recv_frame_await(Frame, Sock, Received); %% rsv: 3 bits recv_frame(Frame = #frame{rsv = none}, Sock, <>) -> NewFrame = Frame#frame{rsv = RSV}, recv_frame(NewFrame, Sock, Rest); recv_frame(Frame = #frame{rsv = none}, Sock, Received) -> + recv_frame_await(Frame, Sock, Received); +%% opcode: 4 bits +recv_frame(Frame = #frame{opcode = none}, Sock, <>) -> + Opcode = + case OpcodeInt of + 0 -> continuation; + 1 -> text; + 2 -> binary; + 8 -> close; + 9 -> ping; + 10 -> pong; + _ -> bad_opcode + end, + case Opcode of + bad_opcode -> + {error, {bad_opcode, OpcodeInt}}; + _ -> + NewFrame = Frame#frame{opcode = Opcode}, + recv_frame(NewFrame, Sock, Rest) + end; +recv_frame(Frame = #frame{opcode = none}, Sock, Received) -> + recv_frame_await(Frame, Sock, Received); +%% mask: 1 bit +recv_frame(Frame = #frame{mask = none}, Sock, <>) -> + NewMask = + case MaskBit of + 0 -> false; + 1 -> true + end, + NewFrame = Frame#frame{mask = NewMask}, + recv_frame(NewFrame, Sock, Rest); +recv_frame(Frame = #frame{mask = none}, Sock, Received = <<>>) -> + recv_frame_await(Frame, Sock, Received); +%% payload length: variable (yay) +% first case: short length 0..125 +recv_frame(Frame = #frame{payload_length = none}, Sock, <>) when Len =< 125 -> + NewFrame = Frame#frame{payload_length = Len}, + recv_frame(NewFrame, Sock, Rest); +% second case: 126 -> 2 bytes to follow +recv_frame(Frame = #frame{payload_length = none}, Sock, <<126:7, Len:16, Rest/bits>>) -> + NewFrame = Frame#frame{payload_length = Len}, + recv_frame(NewFrame, Sock, Rest); +% third case: 127 -> 8 bytes to follow +% bytes must start with a 0 bit +recv_frame(_Frame = #frame{payload_length = none}, _Sock, <<127:7, 1:1, _/bits>>) -> + {error, {illegal_frame, "payload length >= 1 bsl 63"}}; +% 127, next is a legal length, continue +recv_frame(Frame = #frame{payload_length = none}, Sock, <<127:7, Len:64, Rest/bits>>) -> + NewFrame = Frame#frame{payload_length = Len}, + recv_frame(NewFrame, Sock, Rest); +% otherwise wait +recv_frame(Frame = #frame{payload_length = none}, Sock, Received) -> + recv_frame_await(Frame, Sock, Received); +%% masking key: 0 or 4 bits +% not expecting a masking key, fill in that field here +recv_frame(Frame = #frame{mask = false, masking_key = none}, Sock, Received) -> + NewFrame = Frame#frame{masking_key = <<>>}, + recv_frame(NewFrame, Sock, Received); +% expecting one +recv_frame(Frame = #frame{mask = true, masking_key = none}, Sock, <>) -> + NewFrame = Frame#frame{masking_key = Key}, + recv_frame(NewFrame, Sock, Rest); +% not found +recv_frame(Frame = #frame{mask = true, masking_key = none}, Sock, Received) -> + recv_frame_await(Frame, Sock, Received); +%% payload +recv_frame(Frame = #frame{payload_length = Len, payload = none}, Sock, Received) when is_integer(Len) -> + case Received of + % we have enough bytes + <> -> + FinalFrame = Frame#frame{payload = Payload}, + {ok, FinalFrame, Rest}; + % we do not have enough bytes + _ -> + recv_frame_await(Frame, Sock, Received) + end. + + + +%% factoring this out into a function to reduce repetition +recv_frame_await(Frame, Sock, Received) -> case inet:setopts(Sock, [{active, once}]) of ok -> receive - {tcp, Sock, Bin} -> recv_frame(Frame, Sock, <>); - {tcp_closed, Socket} -> {error, tcp_closed}; - {tcp_error, Socket, Reason} -> {error, {tcp_error, Reason}} + {tcp, Sock, Bin} -> recv_frame(Frame, Sock, <>); + {tcp_closed, Sock} -> {error, tcp_closed}; + {tcp_error, Sock, Reason} -> {error, {tcp_error, Reason}} after 3000 -> {error, timeout} end; {error, Reason} -> {error, {inet, Reason}} - end; -%% opcode -recv_frame(Frame = #frame{opcode = none}, Sock, <>) -> - if - OpcodeInt =:= - end; + end. + -spec send(Socket, Payload) -> Result @@ -372,6 +612,10 @@ recv_frame(Frame = #frame{opcode = none}, Sock, <>) -> RestData :: binary() | erlang:iovec(). % @doc % send binary data over Socket. handles frame nonsense +% +% types the payload as bytes +% +% max payload size is 2^64 - 1 bytes % @end send(Socket, Payload) -> @@ -380,7 +624,7 @@ send(Socket, Payload) -> send_frame(Socket, Frame). -payload_to_frame(Payload) when byte_size(Payload) < (1 bsl 64) -> +payload_to_frame(Payload) when byte_size(Payload) < ?MAX_PAYLOAD_SIZE -> #frame{fin = true, opcode = binary, mask = false, @@ -395,8 +639,7 @@ payload_to_frame(Payload) when byte_size(Payload) < (1 bsl 64) -> Frame :: frame(), Result :: ok | {error, Reason}, - Reason :: closed | {timeout, RestData} | inet:posix(), - RestData :: binary() | erlang:iovec(). + Reason :: tcp_error(). % @private % send a frame on the socket % @end @@ -412,6 +655,14 @@ send_frame(Sock, Frame) -> Binary :: binary(). % @private % render a frame +% +% TODO: this doesn't check/do masking +% +% This is a non-issue as long as this is only used for rendering messages sent +% from server to client (unmasked per protocol). However, for debugging +% purposes, a user of this library might want to test how frames render with +% masking. This functionality is not currently supported, but is a planned +% addition in the future. % @end render_frame(#frame{fin = Fin, @@ -466,7 +717,7 @@ render_payload_length(Len) when 0 =< Len, Len =< 125 -> <>; render_payload_length(Len) when 126 =< Len, Len =< 2#1111_1111_1111_1111 -> <<126:7, Len:16>>; -render_payload_length(Len) when (1 bsl 16) =< Len, Len < (1 bsl 63) -> +render_payload_length(Len) when (1 bsl 16) =< Len, Len < ?MAX_PAYLOAD_SIZE -> <<127:7, Len:64>>. @@ -479,8 +730,21 @@ render_payload_length(Len) when (1 bsl 16) =< Len, Len < (1 bsl 63) -> RestData :: binary() | erlang:iovec(). pong(Sock) -> + pong(Sock, <<>>). + + + +-spec pong(Socket, Payload) -> Result + when Socket :: gen_tcp:socket(), + Payload :: binary(), + Result :: ok + | {error, Reason}, + Reason :: closed | {timeout, RestData} | inet:posix(), + RestData :: binary() | erlang:iovec(). + +pong(Sock, Payload) when is_binary(Payload), byte_size(Payload) < ?MAX_PAYLOAD_SIZE -> Frame = #frame{fin = true, opcode = pong, - payload_length = 0, - payload = <<>>}, + payload_length = byte_size(Payload), + payload = Payload}, send_frame(Sock, Frame).