Introduce generic enoise:handshake and implement tcp in terms of it

This commit is contained in:
Hans Svensson
2018-03-09 10:26:49 +01:00
parent 7805b983b2
commit 2f220d599c
2 changed files with 126 additions and 54 deletions
+94 -43
View File
@@ -4,15 +4,19 @@
%%% @doc Module is an interface to the Noise protocol
%%% [https://noiseprotocol.org]
%%%
%%% The module implements Noise over TCP (i.e. `gen_tcp') and after "upgrading"
%%% a `gen_tcp'-socket into a `enoise'-socket it has a similar API as
%%% `gen_tcp'.
%%% The module implements Noise handshake in `handshake/3'.
%%%
%%% @end
%%% ------------------------------------------------------------------
%%% For convenience there is also an API to use Noise over TCP (i.e. `gen_tcp')
%%% and after "upgrading" a `gen_tcp'-socket into a `enoise'-socket it has a
%%% similar API as `gen_tcp'.
%%%
%%% @end ------------------------------------------------------------------
-module(enoise).
%% Main function with generic Noise handshake
-export([handshake/3]).
%% API exports - Mainly mimicing gen_tcp
-export([ accept/2
, close/1
@@ -56,6 +60,28 @@ binary().
%% API functions
%%====================================================================
%% @doc The main function - performs a Noise handshake
handshake(Options, Role, ComState) ->
Prologue = proplists:get_value(prologue, Options, <<>>),
NoiseProtocol0 = proplists:get_value(noise, Options),
NoiseProtocol =
case NoiseProtocol0 of
X when is_binary(X); is_list(X) ->
enoise_protocol:from_name(X);
_ -> NoiseProtocol0
end,
S = proplists:get_value(s, Options, undefined),
E = proplists:get_value(e, Options, undefined),
RS = proplists:get_value(rs, Options, undefined),
RE = proplists:get_value(re, Options, undefined),
HSState = enoise_hs_state:init(NoiseProtocol, Role,
Prologue, {S, E, RS, RE}),
do_handshake(HSState, ComState).
%% @doc Upgrades a gen_tcp, or equivalent, connected socket to a Noise socket,
%% that is, performs the client-side noise handshake.
%%
@@ -65,7 +91,7 @@ binary().
Options :: noise_options()) ->
{ok, noise_socket()} | {error, term()}.
connect(TcpSock, Options) ->
start_handshake(TcpSock, initiator, Options).
tcp_handshake(TcpSock, initiator, Options).
%% @doc Upgrades a gen_tcp, or equivalent, connected socket to a Noise socket,
%% that is, performs the server-side noise handshake.
@@ -76,7 +102,7 @@ connect(TcpSock, Options) ->
Options :: noise_options()) ->
{ok, noise_socket()} | {error, term()}.
accept(TcpSock, Options) ->
start_handshake(TcpSock, responder, Options).
tcp_handshake(TcpSock, responder, Options).
%% @doc Writes `Data' to `Socket'
%% @end
@@ -125,69 +151,94 @@ controlling_process(#enoise{ pid = Pid }, NewPid) ->
%%====================================================================
%% Internal functions
%%====================================================================
start_handshake(TcpSock, Role, Options) ->
case check_tcp(TcpSock) of
{ok, WasActive} ->
inet:setopts(TcpSock, [{active, false}]), %% False for handshake
Prologue = proplists:get_value(prologue, Options, <<>>),
NoiseProtocol = proplists:get_value(noise, Options),
tcp_handshake(TcpSock, Role, Options) ->
case check_gen_tcp(TcpSock) of
ok ->
{ok, [{active, Active}]} = inet:getopts(TcpSock, [active]),
ComState = #{ recv_msg => fun gen_tcp_rcv_msg/1,
send_msg => fun gen_tcp_snd_msg/2,
state => {TcpSock, Active, <<>>} },
S = proplists:get_value(s, Options, undefined),
E = proplists:get_value(e, Options, undefined),
RS = proplists:get_value(rs, Options, undefined),
RE = proplists:get_value(re, Options, undefined),
HSState = enoise_hs_state:init(NoiseProtocol, Role,
Prologue, {S, E, RS, RE}),
do_handshake(TcpSock, HSState, WasActive);
case handshake(Options, Role, ComState) of
{ok, #{ rx := Rx, tx := Tx }, #{ state := {_, _, Buf} }} ->
{ok, Pid} = enoise_connection:start_link(TcpSock, Rx, Tx, self(), {Active, Buf}),
{ok, #enoise{ pid = Pid }};
Err = {error, _} ->
Err
end;
Err = {error, _} ->
Err
end.
do_handshake(TcpSock, HState, WasActive) ->
do_handshake(HState, ComState) ->
case enoise_hs_state:next_message(HState) of
in ->
case hs_recv(TcpSock) of
{ok, Data} ->
case hs_recv_msg(ComState) of
{ok, Data, ComState1} ->
{ok, HState1, _Msg} = enoise_hs_state:read_message(HState, Data),
do_handshake(TcpSock, HState1, WasActive);
do_handshake(HState1, ComState1);
Err = {error, _} ->
Err
end;
out ->
{ok, HState1, Msg} = enoise_hs_state:write_message(HState, <<>>),
hs_send(TcpSock, Msg),
do_handshake(TcpSock, HState1, WasActive);
{ok, ComState1} = hs_send_msg(ComState, Msg),
do_handshake(HState1, ComState1);
done ->
{ok, #{ rx := Rx, tx := Tx }} = enoise_hs_state:finalize(HState),
{ok, Pid} = enoise_connection:start_link(TcpSock, Rx, Tx, self(), WasActive),
{ok, #enoise{ pid = Pid }}
{ok, Res} = enoise_hs_state:finalize(HState),
{ok, Res, ComState}
end.
check_tcp(TcpSock) ->
{ok, TcpOpts} = inet:getopts(TcpSock, [mode, packet, active, header, packet_size]),
hs_recv_msg(CS = #{ recv_msg := Recv, state := S }) ->
case Recv(S) of
{ok, Data, S1} -> {ok, Data, CS#{ state := S1 }};
Err = {error, _} -> Err
end.
hs_send_msg(CS = #{ send_msg := Send, state := S }, Data) ->
case Send(S, Data) of
{ok, S1} -> {ok, CS#{ state := S1 }};
Err = {error, _} -> Err
end.
%% -- gen_tcp specific functions ---------------------------------------------
check_gen_tcp(TcpSock) ->
{ok, TcpOpts} = inet:getopts(TcpSock, [mode, packet, header, packet_size]),
Packet = proplists:get_value(packet, TcpOpts, 0),
Header = proplists:get_value(header, TcpOpts, 0),
Active = proplists:get_value(active, TcpOpts, true),
PSize = proplists:get_value(packet_size, TcpOpts, undefined),
Mode = proplists:get_value(mode, TcpOpts, binary),
case (Packet == 0 orelse Packet == raw)
andalso Header == 0 andalso PSize == 0 andalso Mode == binary of
true ->
case gen_tcp:controlling_process(TcpSock, self()) of
ok -> {ok, Active};
Err = {error, _} -> Err
end;
gen_tcp:controlling_process(TcpSock, self());
false ->
{error, {invalid_tcp_options, proplists:delete(active, TcpOpts)}}
{error, {invalid_tcp_options, TcpOpts}}
end.
hs_send(TcpSock, Msg) ->
gen_tcp_snd_msg(S = {TcpSock, _, _}, Msg) ->
Len = byte_size(Msg),
gen_tcp:send(TcpSock, <<Len:16, Msg/binary>>).
ok = gen_tcp:send(TcpSock, <<Len:16, Msg/binary>>),
{ok, S}.
hs_recv(TcpSock) ->
gen_tcp_rcv_msg({TcpSock, true, Buf}) ->
receive {tcp, TcpSock, Data} ->
case <<Buf/binary, Data/binary>> of
Buf1 = <<Len:16, Rest/binary>> when byte_size(Rest) < Len ->
gen_tcp_rcv_msg({TcpSock, true, Buf1});
<<Len:16, Rest/binary>> ->
<<Data1:Len/binary, Buf1/binary>> = Rest,
{ok, Data1, {TcpSock, true, Buf1}}
end
after 1000 ->
{error, timeout}
end;
gen_tcp_rcv_msg(S = {TcpSock, false, <<>>}) ->
{ok, <<Len:16>>} = gen_tcp:recv(TcpSock, 2, 1000),
gen_tcp:recv(TcpSock, Len, 1000).
case gen_tcp:recv(TcpSock, Len, 1000) of
{ok, Data} -> {ok, Data, S};
Err = {error, _} -> Err
end.