Compare commits
2 Commits
9107679dfc
...
35dbf06a55
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
35dbf06a55 | ||
|
|
62d0710fcf |
409
src/fd_ws.erl
409
src/fd_ws.erl
@ -8,7 +8,10 @@
|
||||
|
||||
-export([
|
||||
handshake/1,
|
||||
response_token/1
|
||||
response_token/1,
|
||||
recv/2,
|
||||
send/2,
|
||||
pong/1,
|
||||
]).
|
||||
|
||||
-include("http.hrl").
|
||||
@ -22,7 +25,7 @@
|
||||
Result :: {ok, ClientProtocols, ClientExtensions, DraftResponse}
|
||||
| {error, Reason},
|
||||
ClientProtocols :: [binary()],
|
||||
ClientExtensions :: [binary()],
|
||||
ClientExtensions :: binary(),
|
||||
DraftResponse :: response(),
|
||||
Reason :: any().
|
||||
% @doc
|
||||
@ -41,88 +44,211 @@
|
||||
% {"Connection", "Upgrade"},
|
||||
% {"Upgrade", "websocket"}].
|
||||
%
|
||||
% YOU are responsible for dealing with any cookie logic, adding the retarded
|
||||
% web date, rendering the response, etc.
|
||||
% YOU are responsible for dealing with any cookie logic, authentication logic,
|
||||
% validating the Origin field, implementing cross-site-request-forgery, adding
|
||||
% the retarded web date, rendering the response, sending it over the socket,
|
||||
% etc.
|
||||
%
|
||||
% ClientExtensions only joins the <<"sec-websocket-extensions">> fields with ", "
|
||||
|
||||
handshake(R = #request{method = get, headers = Hs}) ->
|
||||
%% downcase the headers because have to match on them
|
||||
handshake2(R#request{headers = casefold_headers(Hs)});
|
||||
handshake(_) ->
|
||||
{error, bad_request}.
|
||||
{error, bad_method}.
|
||||
|
||||
|
||||
casefold_headers([{K, V} | Rest]) ->
|
||||
[{unicode:characters_to_binary(string:casefold(K)), V} | casefold_headers(Rest)];
|
||||
casefold_headers([]) ->
|
||||
[].
|
||||
|
||||
handshake2(Req = #request{headers = DowncaseHeaders}) ->
|
||||
-spec casefold_headers(Headers) -> DowncaseHeaders
|
||||
when Headers :: [{Key, Value}],
|
||||
Key :: binary(),
|
||||
Value :: binary(),
|
||||
DowncaseHeaders :: [{LowercaseKey, Value}],
|
||||
LowercaseKey :: binary().
|
||||
% @private
|
||||
% casefold all the keys in the header because they're case insensitive
|
||||
|
||||
casefold_headers(Headers) ->
|
||||
Downcase =
|
||||
fun({K, V}) ->
|
||||
NewKey = unicode:characters_to_binary(string:casefold(K)),
|
||||
{NewKey, V}
|
||||
end,
|
||||
lists:map(Downcase, Headers).
|
||||
|
||||
|
||||
|
||||
-spec handshake2(DowncaseReq) -> Result
|
||||
when DowncaseReq :: request(),
|
||||
Result :: {ok, ClientProtocols, ClientExtensions, DraftResponse}
|
||||
| {error, Reason},
|
||||
ClientProtocols :: [binary()],
|
||||
ClientExtensions :: binary(),
|
||||
DraftResponse :: response(),
|
||||
Reason :: any().
|
||||
% @private
|
||||
% we may assume (WMA) method=get and headers have all been downcased
|
||||
|
||||
handshake2(#request{headers = DowncaseHeaders}) ->
|
||||
% headers MUST contain fields:
|
||||
% sec-websocket-key: _ % arbitrary
|
||||
% sec-websocket-version: 13 % must be EXACTLY 13
|
||||
% connection: Upgrade % must include the token "Upgrade"
|
||||
% upgrade: websocket % must include the token "websocket"
|
||||
MapHeaders = maps:from_list(DowncaseHeaders)
|
||||
ClientProtocols = client_protocols(MapHeaders),
|
||||
ClientExtensions = client_extensions(MapHeaders),
|
||||
case validate_headers(MapHeaders) of
|
||||
MapHeaders = maps:from_list(DowncaseHeaders),
|
||||
ClientProtocols = client_protocols(MapHeaders),
|
||||
ClientExtensions = client_extensions(DowncaseHeaders),
|
||||
MaybeResponseToken = validate_headers(MapHeaders),
|
||||
case MaybeResponseToken of
|
||||
{ok, ResponseToken} ->
|
||||
{ok, ClientProtocols,
|
||||
ClientExtensions,
|
||||
DraftResponse =
|
||||
#response{code = 101,
|
||||
slogan = "Switching Protocols",
|
||||
headers = [{"Sec-WebSocket-Accept", ResponseToken},
|
||||
{"Connection", "Upgrade"},
|
||||
{"Upgrade", "websocket"}]}};
|
||||
{"Upgrade", "websocket"}]},
|
||||
{ok, ClientProtocols,
|
||||
ClientExtensions,
|
||||
DraftResponse};
|
||||
Error ->
|
||||
Error
|
||||
end.
|
||||
|
||||
client_protocols(#{<<"sec-websocket-protocol">> := CommaSeparatedProtocols}) ->
|
||||
Protocols = string:split(CommaSeparatedProtocols, ",", all),
|
||||
Clean =
|
||||
fun(String) ->
|
||||
unicode:characters_to_binary(string:trim(String))
|
||||
end,
|
||||
lists:map(Clean, Protocols).
|
||||
|
||||
client_extensions(#{<<"sec-websocket-extensions">> := CommaSeparatedExtensions}) ->
|
||||
Extensions = string:split(CommaSeparatedExtensions, ",", all),
|
||||
Clean =
|
||||
fun(String) ->
|
||||
unicode:characters_to_binary(string:trim(String))
|
||||
end,
|
||||
lists:map(Clean, Extensions).
|
||||
|
||||
-spec client_protocols(Headers) -> Protocols
|
||||
when Headers :: [{binary(), binary()}],
|
||||
Protocols :: [binary()].
|
||||
% @private
|
||||
% needs to loop through all the headers and unfuck multiline bullshit
|
||||
|
||||
client_protocols(FuckedHeaders) ->
|
||||
unfuck_protocol_string(FuckedHeaders, []).
|
||||
|
||||
unfuck_protocol_string([{<<"sec-websocket-protocol">>, Part} | Rest], Acc) ->
|
||||
unfuck_protocol_string(Rest, [Part | Acc]);
|
||||
unfuck_protocol_string([_ | Rest], Acc) ->
|
||||
unfuck_protocol_string(Rest, Acc);
|
||||
unfuck_protocol_string([], PartsRev) ->
|
||||
Parts = lists:reverse(PartsRev),
|
||||
% have to join everything together and then re-split
|
||||
CSVBin = unicode:characters_to_binary(lists:join(", ", Parts)),
|
||||
% after the surgery
|
||||
TrannyParts = string:split(CSVBin, ",", all),
|
||||
% trim the parts
|
||||
JewParts = lists:map(fun circumcise/1, TrannyParts),
|
||||
JewParts.
|
||||
|
||||
|
||||
|
||||
-spec client_extensions(Headers) -> binary()
|
||||
when Headers :: [{Key, Val}],
|
||||
Key :: binary(),
|
||||
Val :: binary().
|
||||
|
||||
client_extensions(DowncaseHeaders) ->
|
||||
unfuck_extensions_string(DowncaseHeaders, []).
|
||||
|
||||
|
||||
-spec unfuck_extensions_string(KVPairs, Acc) -> Unfucked
|
||||
when KVPairs :: [{Key :: binary(), Val :: binary()}],
|
||||
Acc :: Unfucked,
|
||||
Unfucked :: binary().
|
||||
% @private
|
||||
% quoth section 9.1: https://datatracker.ietf.org/doc/html/rfc6455#section-9.1
|
||||
%
|
||||
% > Note that like other HTTP header fields, this header field MAY be
|
||||
% > split or combined across multiple lines. Ergo, the following are
|
||||
% > equivalent:
|
||||
% >
|
||||
% > Sec-WebSocket-Extensions: foo
|
||||
% > Sec-WebSocket-Extensions: bar; baz=2
|
||||
% >
|
||||
% > is exactly equivalent to
|
||||
% >
|
||||
% > Sec-WebSocket-Extensions: foo, bar; baz=2
|
||||
%
|
||||
% basically have to go through the entire proplist of headers, and if it
|
||||
% matches <<"sec-websocket-extensions">>, then csv its value to the thing
|
||||
% @end
|
||||
|
||||
unfuck_extensions_string([{<<"sec-websocket-extensions">>, Part} | Rest], Acc) ->
|
||||
unfuck_extensions_string(Rest, [Part | Acc]);
|
||||
unfuck_extensions_string([_ | Rest], Acc) ->
|
||||
unfuck_extensions_string(Rest, Acc);
|
||||
unfuck_extensions_string([], PartsRev) ->
|
||||
% in the example above, PartsRev = [<<"bar; baz=2">>, <<"foo">>],
|
||||
% so need to reverse and then join with commas
|
||||
circumcise(lists:join(<<", ">>, lists:reverse(PartsRev))).
|
||||
|
||||
|
||||
|
||||
-spec circumcise(unicode:chardata()) -> binary().
|
||||
% @private delete leading/trailing whitespace then convert to binary
|
||||
|
||||
circumcise(String) ->
|
||||
unicode:characters_to_binary(string:trim(String)).
|
||||
|
||||
|
||||
|
||||
-spec validate_headers(HeadersMap) -> Result
|
||||
when HeadersMap :: #{Key :: binary() := Val :: binary()},
|
||||
Result :: {ok, ResponseToken}
|
||||
| {error, Reason},
|
||||
ResponseToken :: binary(),
|
||||
Reason :: any().
|
||||
% @private
|
||||
% validate:
|
||||
% Upgrade: websocket
|
||||
% Connection: Upgrade
|
||||
% Sec-WebSocket-Version: 13
|
||||
|
||||
validate_headers(#{<<"sec-websocket-key">> := ChallengeToken,
|
||||
<<"sec-websocket-version">> := WS_Vsn,
|
||||
<<"connection">> := Connection,
|
||||
<<"upgrade">> := Upgrade}) ->
|
||||
BadUpgrade = bad_upgrade(Upgrade),
|
||||
BadConnection = bad_connection(Connection),
|
||||
BadVersion = bad_version(WS_Vsn),
|
||||
if
|
||||
bad_upgrade(Upgrade) -> {error, {bad_upgrade, Upgrade}};
|
||||
bad_connection(Connection) -> {error, {bad_connection, Connection}};
|
||||
bad_version(WS_Vsn) -> {error, {bad_version, WS_Vsn}};
|
||||
true -> {ok, response_token(ChallengeToken)}
|
||||
end.
|
||||
BadUpgrade -> {error, {bad_upgrade, Upgrade}};
|
||||
BadConnection -> {error, {bad_connection, Connection}};
|
||||
BadVersion -> {error, {bad_version, WS_Vsn}};
|
||||
true -> {ok, response_token(ChallengeToken)}
|
||||
end;
|
||||
validate_headers(_) ->
|
||||
{error, bad_request}.
|
||||
|
||||
|
||||
|
||||
-spec bad_upgrade(binary()) -> true | false.
|
||||
% @private string must include "websocket" as a token
|
||||
|
||||
% string must include "websocket" as a token
|
||||
bad_upgrade(Str) ->
|
||||
case string:find(Str, "websocket") of
|
||||
nomatch -> true;
|
||||
_ -> false
|
||||
end.
|
||||
|
||||
% string must include "Upgrade" as a token
|
||||
|
||||
|
||||
-spec bad_connection(binary()) -> true | false.
|
||||
% @private string must include "Upgrade" as a token
|
||||
|
||||
bad_connection(Str) ->
|
||||
case string:find(Str, "Upgrade") of
|
||||
nomatch -> true;
|
||||
_ -> false
|
||||
end.
|
||||
|
||||
% version must be EXACTLY <<"13">>
|
||||
bad_version(<<"13">> -> false;
|
||||
bad_version(_) -> true.
|
||||
|
||||
|
||||
-spec bad_version(binary()) -> true | false.
|
||||
% @private version must be EXACTLY <<"13">>
|
||||
|
||||
bad_version(<<"13">>) -> false;
|
||||
bad_version(_) -> true.
|
||||
|
||||
|
||||
|
||||
-spec response_token(binary()) -> binary().
|
||||
@ -157,3 +283,204 @@ response_token(ChallengeToken) when is_binary(ChallengeToken) ->
|
||||
ConcatString = <<ChallengeToken/binary, MagicString/binary>>,
|
||||
Sha1 = crypto:hash(sha, ConcatString),
|
||||
base64:encode(Sha1).
|
||||
|
||||
|
||||
|
||||
-type opcode() :: continuation
|
||||
| text
|
||||
| binary
|
||||
| close
|
||||
| ping
|
||||
| pong.
|
||||
|
||||
-record(frame,
|
||||
{fin = none :: none | boolean(),
|
||||
rsv = none :: none | <<_:3>>,
|
||||
opcode = none :: none | opcode(),
|
||||
mask = none :: none | boolean(),
|
||||
payload_length = none :: none | non_neg_integer(),
|
||||
masking_key = none :: none | <<_:256>>,
|
||||
payload = none :: none | binary()}).
|
||||
|
||||
-type frame() :: #frame{}.
|
||||
|
||||
|
||||
|
||||
-spec recv(Socket, Received) -> Result
|
||||
when Socket :: gen_tcp:socket(),
|
||||
Received :: binary(),
|
||||
Result :: {ok, binary()}
|
||||
| {error, Reason}
|
||||
Reason :: any().
|
||||
|
||||
recv(Sock, Recv) ->
|
||||
recv(#frame{}, Sock, Recv).
|
||||
|
||||
|
||||
%% frame: 1 bit
|
||||
recv_frame(Frame = #frame{fin = none}, Sock, <<FinBit:1, Rest/bits>>) ->
|
||||
NewFin =
|
||||
case FinBit of
|
||||
0 -> false;
|
||||
1 -> true
|
||||
end,
|
||||
NewFrame = Frame#frame{fin = NewFin},
|
||||
recv_frame(NewFrame, Sock, Rest);
|
||||
recv_frame(Frame = #frame{fin = none}, Sock, <<>>) ->
|
||||
case inet:setopts(Sock, [{active, once}]) of
|
||||
ok ->
|
||||
receive
|
||||
{tcp, Sock, Bin} -> recv_frame(Frame, Sock, Bin);
|
||||
{tcp_closed, Socket} -> {error, tcp_closed};
|
||||
{tcp_error, Socket, Reason} -> {error, {tcp_error, Reason}}
|
||||
after 3000 ->
|
||||
{error, timeout}
|
||||
end;
|
||||
{error, Reason} ->
|
||||
{error, {inet, Reason}}
|
||||
end;
|
||||
%% rsv: 3 bits
|
||||
recv_frame(Frame = #frame{rsv = none}, Sock, <<RSV:3/bits, Rest/bits>>) ->
|
||||
NewFrame = Frame#frame{rsv = RSV},
|
||||
recv_frame(NewFrame, Sock, Rest);
|
||||
recv_frame(Frame = #frame{rsv = none}, Sock, Received) ->
|
||||
case inet:setopts(Sock, [{active, once}]) of
|
||||
ok ->
|
||||
receive
|
||||
{tcp, Sock, Bin} -> recv_frame(Frame, Sock, <<Received/bits, Bin/binary>>);
|
||||
{tcp_closed, Socket} -> {error, tcp_closed};
|
||||
{tcp_error, Socket, Reason} -> {error, {tcp_error, Reason}}
|
||||
after 3000 ->
|
||||
{error, timeout}
|
||||
end;
|
||||
{error, Reason} ->
|
||||
{error, {inet, Reason}}
|
||||
end;
|
||||
%% opcode
|
||||
recv_frame(Frame = #frame{opcode = none}, Sock, <<OpcodeInt:4, Rest/bits>>) ->
|
||||
if
|
||||
OpcodeInt =:=
|
||||
end;
|
||||
|
||||
|
||||
-spec send(Socket, Payload) -> Result
|
||||
when Socket :: gen_tcp:socket(),
|
||||
Payload :: iodata(),
|
||||
Result :: ok
|
||||
| {error, Reason},
|
||||
Reason :: closed | {timeout, RestData} | inet:posix(),
|
||||
RestData :: binary() | erlang:iovec().
|
||||
% @doc
|
||||
% send binary data over Socket. handles frame nonsense
|
||||
% @end
|
||||
|
||||
send(Socket, Payload) ->
|
||||
BPayload = unicode:characters_to_binary(Payload),
|
||||
Frame = payload_to_frame(BPayload),
|
||||
send_frame(Socket, Frame).
|
||||
|
||||
|
||||
payload_to_frame(Payload) when byte_size(Payload) < (1 bsl 64) ->
|
||||
#frame{fin = true,
|
||||
opcode = binary,
|
||||
mask = false,
|
||||
payload_length = byte_size(Payload),
|
||||
masking_key = none,
|
||||
payload = Payload}.
|
||||
|
||||
|
||||
|
||||
-spec send_frame(Sock, Frame) -> Result
|
||||
when Sock :: gen_tcp:socket(),
|
||||
Frame :: frame(),
|
||||
Result :: ok
|
||||
| {error, Reason},
|
||||
Reason :: closed | {timeout, RestData} | inet:posix(),
|
||||
RestData :: binary() | erlang:iovec().
|
||||
% @private
|
||||
% send a frame on the socket
|
||||
% @end
|
||||
|
||||
send_frame(Sock, Frame) ->
|
||||
Binary = render_frame(Frame),
|
||||
gen_tcp:send(Sock, Binary).
|
||||
|
||||
|
||||
|
||||
-spec render_frame(Frame) -> Binary
|
||||
when Frame :: frame(),
|
||||
Binary :: binary().
|
||||
% @private
|
||||
% render a frame
|
||||
% @end
|
||||
|
||||
render_frame(#frame{fin = Fin,
|
||||
opcode = Opcode,
|
||||
payload_length = Len,
|
||||
payload = Payload}) ->
|
||||
BFin =
|
||||
case Fin of
|
||||
true -> <<1:1>>;
|
||||
false -> <<0:1>>
|
||||
end,
|
||||
BRSV = <<0:3>>,
|
||||
BOpcode =
|
||||
case Opcode of
|
||||
continuation -> << 0:1>>;
|
||||
text -> << 1:1>>;
|
||||
binary -> << 2:1>>;
|
||||
close -> << 8:1>>;
|
||||
ping -> << 9:1>>;
|
||||
pong -> <<10:1>>
|
||||
end,
|
||||
BMask = <<0:1>>,
|
||||
BPayloadLength = render_payload_length(Len),
|
||||
<<BFin/bits,
|
||||
BRSV/bits,
|
||||
BOpcode/bits,
|
||||
BMask/bits,
|
||||
BPayloadLength/bits,
|
||||
Payload/bits>>.
|
||||
|
||||
|
||||
|
||||
-spec render_payload_length(non_neg_integer()) -> binary().
|
||||
% @private
|
||||
% > Payload length: 7 bits, 7+16 bits, or 7+64 bits
|
||||
% >
|
||||
% > The length of the "Payload data", in bytes: if 0-125, that is the
|
||||
% > payload length. If 126, the following 2 bytes interpreted as a
|
||||
% > 16-bit unsigned integer are the payload length. If 127, the
|
||||
% > following 8 bytes interpreted as a 64-bit unsigned integer (the
|
||||
% > most significant bit MUST be 0) are the payload length. Multibyte
|
||||
% > length quantities are expressed in network byte order. Note that
|
||||
% > in all cases, the minimal number of bytes MUST be used to encode
|
||||
% > the length, for example, the length of a 124-byte-long string
|
||||
% > can't be encoded as the sequence 126, 0, 124. The payload length
|
||||
% > is the length of the "Extension data" + the length of the
|
||||
% > "Application data". The length of the "Extension data" may be
|
||||
% > zero, in which case the payload length is the length of the
|
||||
% > "Application data".
|
||||
|
||||
render_payload_length(Len) when 0 =< Len, Len =< 125 ->
|
||||
<<Len:7>>;
|
||||
render_payload_length(Len) when 126 =< Len, Len =< 2#1111_1111_1111_1111 ->
|
||||
<<126:7, Len:16>>;
|
||||
render_payload_length(Len) when (1 bsl 16) =< Len, Len < (1 bsl 63) ->
|
||||
<<127:7, Len:64>>.
|
||||
|
||||
|
||||
|
||||
-spec pong(Socket) -> Result
|
||||
when Socket :: gen_tcp:socket(),
|
||||
Result :: ok
|
||||
| {error, Reason},
|
||||
Reason :: closed | {timeout, RestData} | inet:posix(),
|
||||
RestData :: binary() | erlang:iovec().
|
||||
|
||||
pong(Sock) ->
|
||||
Frame = #frame{fin = true,
|
||||
opcode = pong,
|
||||
payload_length = 0,
|
||||
payload = <<>>},
|
||||
send_frame(Sock, Frame).
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user