diff --git a/src/fd_ws.erl b/src/fd_ws.erl new file mode 100644 index 0000000..8ff67b6 --- /dev/null +++ b/src/fd_ws.erl @@ -0,0 +1,159 @@ +% @doc websockets +% +% ref: https://datatracker.ietf.org/doc/html/rfc6455 +-module(fd_ws). + +-export_type([ +]). + +-export([ + handshake/1, + response_token/1 +]). + +-include("http.hrl"). + +-type request() :: #request{}. +-type response() :: #response{}. + + +-spec handshake(Req) -> Result + when Req :: request(), + Result :: {ok, ClientProtocols, ClientExtensions, DraftResponse} + | {error, Reason}, + ClientProtocols :: [binary()], + ClientExtensions :: [binary()], + DraftResponse :: response(), + Reason :: any(). +% @doc +% This mostly just validates that all the 't's have been dotted and 'i's have +% been crossed. +% +% given an HTTP request: +% +% - if it is NOT a valid websocket handshake request, error +% - if it IS a valid websocket handshake request, form an initial candidate +% response record with the following fields: +% +% code = 101 +% slogan = "Switching Protocols" +% headers = [{"Sec-WebSocket-Accept", ChallengeResponse}, +% {"Connection", "Upgrade"}, +% {"Upgrade", "websocket"}]. +% +% YOU are responsible for dealing with any cookie logic, adding the retarded +% web date, rendering the response, etc. + +handshake(R = #request{method = get, headers = Hs}) -> + %% downcase the headers because have to match on them + handshake2(R#request{headers = casefold_headers(Hs)}); +handshake(_) -> + {error, bad_request}. + + +casefold_headers([{K, V} | Rest]) -> + [{unicode:characters_to_binary(string:casefold(K)), V} | casefold_headers(Rest)]; +casefold_headers([]) -> + []. + +handshake2(Req = #request{headers = DowncaseHeaders}) -> + % headers MUST contain fields: + % sec-websocket-key: _ % arbitrary + % sec-websocket-version: 13 % must be EXACTLY 13 + % connection: Upgrade % must include the token "Upgrade" + % upgrade: websocket % must include the token "websocket" + MapHeaders = maps:from_list(DowncaseHeaders) + ClientProtocols = client_protocols(MapHeaders), + ClientExtensions = client_extensions(MapHeaders), + case validate_headers(MapHeaders) of + {ok, ResponseToken} -> + {ok, ClientProtocols, + ClientExtensions, + #response{code = 101, + slogan = "Switching Protocols", + headers = [{"Sec-WebSocket-Accept", ResponseToken}, + {"Connection", "Upgrade"}, + {"Upgrade", "websocket"}]}}; + Error -> + Error + end. + +client_protocols(#{<<"sec-websocket-protocol">> := CommaSeparatedProtocols}) -> + Protocols = string:split(CommaSeparatedProtocols, ",", all), + Clean = + fun(String) -> + unicode:characters_to_binary(string:trim(String)) + end, + lists:map(Clean, Protocols). + +client_extensions(#{<<"sec-websocket-extensions">> := CommaSeparatedExtensions}) -> + Extensions = string:split(CommaSeparatedExtensions, ",", all), + Clean = + fun(String) -> + unicode:characters_to_binary(string:trim(String)) + end, + lists:map(Clean, Extensions). + + +validate_headers(#{<<"sec-websocket-key">> := ChallengeToken, + <<"sec-websocket-version">> := WS_Vsn, + <<"connection">> := Connection, + <<"upgrade">> := Upgrade}) -> + if + bad_upgrade(Upgrade) -> {error, {bad_upgrade, Upgrade}}; + bad_connection(Connection) -> {error, {bad_connection, Connection}}; + bad_version(WS_Vsn) -> {error, {bad_version, WS_Vsn}}; + true -> {ok, response_token(ChallengeToken)} + end. + +% string must include "websocket" as a token +bad_upgrade(Str) -> + case string:find(Str, "websocket") of + nomatch -> true; + _ -> false + end. + +% string must include "Upgrade" as a token +bad_connection(Str) -> + case string:find(Str, "Upgrade") of + nomatch -> true; + _ -> false + end. + +% version must be EXACTLY <<"13">> +bad_version(<<"13">> -> false; +bad_version(_) -> true. + + +-spec response_token(binary()) -> binary(). +% @doc +% Quoth the RFC: +% +% > Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== +% > +% > For this header field, the server has to take the value (as present +% > in the header field, e.g., the base64-encoded [RFC4648] version minus +% > any leading and trailing whitespace) and concatenate this with the +% > Globally Unique Identifier (GUID, [RFC4122]) "258EAFA5-E914-47DA- +% > 95CA-C5AB0DC85B11" in string form, which is unlikely to be used by +% > network endpoints that do not understand the WebSocket Protocol. A +% > SHA-1 hash (160 bits) [FIPS.180-3], base64-encoded (see Section 4 of +% > [RFC4648]), of this concatenation is then returned in the server's +% > handshake. +% > +% > Concretely, if as in the example above, the |Sec-WebSocket-Key| +% > header field had the value "dGhlIHNhbXBsZSBub25jZQ==", the server +% > would concatenate the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +% > to form the string "dGhlIHNhbXBsZSBub25jZQ==258EAFA5-E914-47DA-95CA- +% > C5AB0DC85B11". The server would then take the SHA-1 hash of this, +% > giving the value 0xb3 0x7a 0x4f 0x2c 0xc0 0x62 0x4f 0x16 0x90 0xf6 +% > 0x46 0x06 0xcf 0x38 0x59 0x45 0xb2 0xbe 0xc4 0xea. This value is +% > then base64-encoded (see Section 4 of [RFC4648]), to give the value +% > "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=". This value would then be echoed in +% > the |Sec-WebSocket-Accept| header field. + +response_token(ChallengeToken) when is_binary(ChallengeToken) -> + MagicString = <<"258EAFA5-E914-47DA-95CA-C5AB0DC85B11">>, + ConcatString = <>, + Sha1 = crypto:hash(sha, ConcatString), + base64:encode(Sha1).