diff --git a/src/fd_ws.erl b/src/fd_ws.erl index c1ac61d..175eac2 100644 --- a/src/fd_ws.erl +++ b/src/fd_ws.erl @@ -8,7 +8,10 @@ -export([ handshake/1, - response_token/1 + response_token/1, + recv/2, + send/2, + pong/1, ]). -include("http.hrl"). @@ -22,9 +25,7 @@ Result :: {ok, ClientProtocols, ClientExtensions, DraftResponse} | {error, Reason}, ClientProtocols :: [binary()], - ClientExtensions :: [Extension], - Extension :: Naked :: binary(), - | Pair :: {binary(), binary()}, + ClientExtensions :: binary(), DraftResponse :: response(), Reason :: any(). % @doc @@ -47,6 +48,8 @@ % validating the Origin field, implementing cross-site-request-forgery, adding % the retarded web date, rendering the response, sending it over the socket, % etc. +% +% ClientExtensions only joins the <<"sec-websocket-extensions">> fields with ", " handshake(R = #request{method = get, headers = Hs}) -> %% downcase the headers because have to match on them @@ -57,11 +60,11 @@ handshake(_) -> -spec casefold_headers(Headers) -> DowncaseHeaders - when Proplist :: [{Key, Value}], - Key :: binary(), - Value :: binary(), - LCProplist :: [{LowercaseKey, Value}], - LowercaseKey :: binary(). + when Headers :: [{Key, Value}], + Key :: binary(), + Value :: binary(), + DowncaseHeaders :: [{LowercaseKey, Value}], + LowercaseKey :: binary(). % @private % casefold all the keys in the header because they're case insensitive @@ -80,9 +83,7 @@ casefold_headers(Headers) -> Result :: {ok, ClientProtocols, ClientExtensions, DraftResponse} | {error, Reason}, ClientProtocols :: [binary()], - ClientExtensions :: [Extension] - Extension :: binary() | Option, - Option :: {Key :: binary(), Value :: binary()}, + ClientExtensions :: binary(), DraftResponse :: response(), Reason :: any(). % @private @@ -98,8 +99,8 @@ handshake2(#request{headers = DowncaseHeaders}) -> ClientProtocols = client_protocols(MapHeaders), ClientExtensions = client_extensions(DowncaseHeaders), MaybeResponseToken = validate_headers(MapHeaders), - case {ClientExtensions, MaybeResponseToken} of - {{ok, Extensions}, {ok, ResponseToken}} -> + case MaybeResponseToken of + {ok, ResponseToken} -> DraftResponse = #response{code = 101, slogan = "Switching Protocols", @@ -107,11 +108,9 @@ handshake2(#request{headers = DowncaseHeaders}) -> {"Connection", "Upgrade"}, {"Upgrade", "websocket"}]}, {ok, ClientProtocols, - Extensions, + ClientExtensions, DraftResponse}; - {{ok, _, _}, Error} -> - Error; - {Error, _} -> + Error -> Error end. @@ -126,7 +125,6 @@ handshake2(#request{headers = DowncaseHeaders}) -> client_protocols(FuckedHeaders) -> unfuck_protocol_string(FuckedHeaders, []). - unfuck_protocol_string([{<<"sec-websocket-protocol">>, Part} | Rest], Acc) -> unfuck_protocol_string(Rest, [Part | Acc]); unfuck_protocol_string([_ | Rest], Acc) -> @@ -143,133 +141,18 @@ unfuck_protocol_string([], PartsRev) -> --spec client_extensions(Headers) -> Result +-spec client_extensions(Headers) -> binary() when Headers :: [{Key, Val}], Key :: binary(), - Val :: binary(), - Result :: {ok, Extensions} - | {error, Reason}, - Extensions :: [Extension], - Extension :: binary() - | {Key, Val}, - Reason :: any(). + Val :: binary(). client_extensions(DowncaseHeaders) -> - UnfuckedExtensionsStr = unfuck_extensions_string(DowncaseHeaders, []), - client_extensions2(UnfuckedExtensionsStr). + unfuck_extensions_string(DowncaseHeaders, []). --spec client_extensions2(UnfuckedExtensionsStr) -% > Note that like other HTTP header fields, this header field MAY be -% > split or combined across multiple lines. Ergo, the following are -% > equivalent: -% > -% > Sec-WebSocket-Extensions: foo -% > Sec-WebSocket-Extensions: bar; baz=2 -% > -% > is exactly equivalent to -% > -% > Sec-WebSocket-Extensions: foo, bar; az=2 -% -% Un is the unfucked (i.e. last version of that) -% -% it may be empty, meaning either -% -% 1. the client was being inarticulate, meaning they sent a -% "Sec-Websocket-Extensions: \r\n" header -% 2. there was no such header -% -% strictly speaking, we're supposed to close the connection if the client is -% being inarticulate. I don't feel like coding all that complexity in, so we're -% just going to treat that exactly as if the client had just not sent -% "sec-websocket-extensions". -client_extensions2(<<>>) -> - {ok, {[], []}}; -client_extensions2(UnfuckedExtensionsStr) -> - case string:split(UnfuckedExtensionsStr, ";", all) of - [CommaFields] -> - {ok, unfuck_comma_fields(CommaFields), []}; - [CommaFields | OptFields] -> - Extensions = unfuck_comma_fields(CommaFields), - case unfuck_options(OptFields, []) of - {ok, Options} -> {ok, Extensions, Options}; - Error -> Error - end; - _ -> - {error, {bad_extensions, UnfuckedExtensionsStr}} - end. - -% <<"hello, world, blah">> -> [<<"hello">>, <<"world">>, <<"blah">>] -unfuck_comma_fields(CSV0) -> - CSV1 = string:trim(CSV0), - Fields = string:split(CSV1, ","), - lists:map(fun circumcise/1, Fields). - - -circumcise(String) -> - unicode:characters_to_binary(string:trim(String)). - -% [<<"foo">>, <<"bar=baz">>, <<"quux">>, <<"fuzz=\"fizz\"">>] -> -% [<<"foo">>, {<<"bar">>, <<"baz">>}, <<"quux">>, {<<"fuzz">>, <<"fizz">>}] -unfuck_options([Optstr | Rest], Acc) -> - case unfuck_option(Optstr) of - {ok, Opt} -> unfuck_options(Rest, [Opt | Acc]); - Error -> Error - end; -unfuck_options([], Acc) -> - lists:reverse(Acc). - - -% <<"foo=bar">> -> {<<"foo", "bar">>} -unfuck_option(Str) -> - case string:split(Str, "=") of - [K, V] -> - case unfuck_val(circumcise(V)) of - {ok, Val} -> {ok, {circumcise(K), Val}}; - Error -> Error - end; - [Opt] -> {ok, circumcise(Opt)}; - _ -> {error, {bad_extension_param, Str}} - end. - - -% val can either be a naked string or a quoted string -unfuck_val(Whole = <<$":8, Rest/binary>>) -> - unquote(Whole, Rest, <<>>); -unfuck_val(X) -> - {ok, X}. - - - --spec unquote(Orig, Parsing, Acc) -> Result - when Orig :: binary(), - Parsing :: binary(), - Acc :: binary(), - Result :: {ok, Unquoted} - | {error, Reason} - Unquoted :: binary(), - Reason :: any(). -% @private -% take the shit out of the quotes - -% trailing quote -> success -unquote(_, <<$">>, Acc) -> - {ok, Acc}; -% trailing quote and more stuff -> error -unquote(Orig, <<$", _/binary>>, _) -> - {error, {bad_extension_param, Orig}}; -% end of string before trailing quote -unquote(Orig, <<>>, _) -> - {error, {bad_extension_param, Orig}}; -unquote(Orig, <>, Acc) -> - unquote(Orig, Rest, <>). - - - --spec unfuck_extensions_string(KVPairs) -> Unfucked - when KVPairs :: [{Key, Val}], - Key :: binary(), - Val :: binary(), +-spec unfuck_extensions_string(KVPairs, Acc) -> Unfucked + when KVPairs :: [{Key :: binary(), Val :: binary()}], + Acc :: Unfucked, Unfucked :: binary(). % @private % quoth section 9.1: https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 @@ -300,16 +183,25 @@ unfuck_extensions_string([], PartsRev) -> +-spec circumcise(unicode:chardata()) -> binary(). +% @private delete leading/trailing whitespace then convert to binary + +circumcise(String) -> + unicode:characters_to_binary(string:trim(String)). + + + -spec validate_headers(HeadersMap) -> Result when HeadersMap :: #{Key :: binary() := Val :: binary()}, Result :: {ok, ResponseToken} - | {error, Reason} + | {error, Reason}, ResponseToken :: binary(), Reason :: any(). % @private % validate: % Upgrade: websocket -% Connection: +% Connection: Upgrade +% Sec-WebSocket-Version: 13 validate_headers(#{<<"sec-websocket-key">> := ChallengeToken, <<"sec-websocket-version">> := WS_Vsn, @@ -351,7 +243,7 @@ bad_connection(Str) -> --spec bad_version(binary()) +-spec bad_version(binary()) -> true | false. % @private version must be EXACTLY <<"13">> bad_version(<<"13">>) -> false; @@ -391,3 +283,204 @@ response_token(ChallengeToken) when is_binary(ChallengeToken) -> ConcatString = <>, Sha1 = crypto:hash(sha, ConcatString), base64:encode(Sha1). + + + +-type opcode() :: continuation + | text + | binary + | close + | ping + | pong. + +-record(frame, + {fin = none :: none | boolean(), + rsv = none :: none | <<_:3>>, + opcode = none :: none | opcode(), + mask = none :: none | boolean(), + payload_length = none :: none | non_neg_integer(), + masking_key = none :: none | <<_:256>>, + payload = none :: none | binary()}). + +-type frame() :: #frame{}. + + + +-spec recv(Socket, Received) -> Result + when Socket :: gen_tcp:socket(), + Received :: binary(), + Result :: {ok, binary()} + | {error, Reason} + Reason :: any(). + +recv(Sock, Recv) -> + recv(#frame{}, Sock, Recv). + + +%% frame: 1 bit +recv_frame(Frame = #frame{fin = none}, Sock, <>) -> + NewFin = + case FinBit of + 0 -> false; + 1 -> true + end, + NewFrame = Frame#frame{fin = NewFin}, + recv_frame(NewFrame, Sock, Rest); +recv_frame(Frame = #frame{fin = none}, Sock, <<>>) -> + case inet:setopts(Sock, [{active, once}]) of + ok -> + receive + {tcp, Sock, Bin} -> recv_frame(Frame, Sock, Bin); + {tcp_closed, Socket} -> {error, tcp_closed}; + {tcp_error, Socket, Reason} -> {error, {tcp_error, Reason}} + after 3000 -> + {error, timeout} + end; + {error, Reason} -> + {error, {inet, Reason}} + end; +%% rsv: 3 bits +recv_frame(Frame = #frame{rsv = none}, Sock, <>) -> + NewFrame = Frame#frame{rsv = RSV}, + recv_frame(NewFrame, Sock, Rest); +recv_frame(Frame = #frame{rsv = none}, Sock, Received) -> + case inet:setopts(Sock, [{active, once}]) of + ok -> + receive + {tcp, Sock, Bin} -> recv_frame(Frame, Sock, <>); + {tcp_closed, Socket} -> {error, tcp_closed}; + {tcp_error, Socket, Reason} -> {error, {tcp_error, Reason}} + after 3000 -> + {error, timeout} + end; + {error, Reason} -> + {error, {inet, Reason}} + end; +%% opcode +recv_frame(Frame = #frame{opcode = none}, Sock, <>) -> + if + OpcodeInt =:= + end; + + +-spec send(Socket, Payload) -> Result + when Socket :: gen_tcp:socket(), + Payload :: iodata(), + Result :: ok + | {error, Reason}, + Reason :: closed | {timeout, RestData} | inet:posix(), + RestData :: binary() | erlang:iovec(). +% @doc +% send binary data over Socket. handles frame nonsense +% @end + +send(Socket, Payload) -> + BPayload = unicode:characters_to_binary(Payload), + Frame = payload_to_frame(BPayload), + send_frame(Socket, Frame). + + +payload_to_frame(Payload) when byte_size(Payload) < (1 bsl 64) -> + #frame{fin = true, + opcode = binary, + mask = false, + payload_length = byte_size(Payload), + masking_key = none, + payload = Payload}. + + + +-spec send_frame(Sock, Frame) -> Result + when Sock :: gen_tcp:socket(), + Frame :: frame(), + Result :: ok + | {error, Reason}, + Reason :: closed | {timeout, RestData} | inet:posix(), + RestData :: binary() | erlang:iovec(). +% @private +% send a frame on the socket +% @end + +send_frame(Sock, Frame) -> + Binary = render_frame(Frame), + gen_tcp:send(Sock, Binary). + + + +-spec render_frame(Frame) -> Binary + when Frame :: frame(), + Binary :: binary(). +% @private +% render a frame +% @end + +render_frame(#frame{fin = Fin, + opcode = Opcode, + payload_length = Len, + payload = Payload}) -> + BFin = + case Fin of + true -> <<1:1>>; + false -> <<0:1>> + end, + BRSV = <<0:3>>, + BOpcode = + case Opcode of + continuation -> << 0:1>>; + text -> << 1:1>>; + binary -> << 2:1>>; + close -> << 8:1>>; + ping -> << 9:1>>; + pong -> <<10:1>> + end, + BMask = <<0:1>>, + BPayloadLength = render_payload_length(Len), + <>. + + + +-spec render_payload_length(non_neg_integer()) -> binary(). +% @private +% > Payload length: 7 bits, 7+16 bits, or 7+64 bits +% > +% > The length of the "Payload data", in bytes: if 0-125, that is the +% > payload length. If 126, the following 2 bytes interpreted as a +% > 16-bit unsigned integer are the payload length. If 127, the +% > following 8 bytes interpreted as a 64-bit unsigned integer (the +% > most significant bit MUST be 0) are the payload length. Multibyte +% > length quantities are expressed in network byte order. Note that +% > in all cases, the minimal number of bytes MUST be used to encode +% > the length, for example, the length of a 124-byte-long string +% > can't be encoded as the sequence 126, 0, 124. The payload length +% > is the length of the "Extension data" + the length of the +% > "Application data". The length of the "Extension data" may be +% > zero, in which case the payload length is the length of the +% > "Application data". + +render_payload_length(Len) when 0 =< Len, Len =< 125 -> + <>; +render_payload_length(Len) when 126 =< Len, Len =< 2#1111_1111_1111_1111 -> + <<126:7, Len:16>>; +render_payload_length(Len) when (1 bsl 16) =< Len, Len < (1 bsl 63) -> + <<127:7, Len:64>>. + + + +-spec pong(Socket) -> Result + when Socket :: gen_tcp:socket(), + Result :: ok + | {error, Reason}, + Reason :: closed | {timeout, RestData} | inet:posix(), + RestData :: binary() | erlang:iovec(). + +pong(Sock) -> + Frame = #frame{fin = true, + opcode = pong, + payload_length = 0, + payload = <<>>}, + send_frame(Sock, Frame).