From 2f220d599c12b1c62fe5bddc7bc60c9b21bee5db Mon Sep 17 00:00:00 2001 From: Hans Svensson Date: Fri, 9 Mar 2018 10:26:49 +0100 Subject: [PATCH] Introduce generic enoise:handshake and implement tcp in terms of it --- src/enoise.erl | 137 ++++++++++++++++++++++++++------------ src/enoise_connection.erl | 43 +++++++++--- 2 files changed, 126 insertions(+), 54 deletions(-) diff --git a/src/enoise.erl b/src/enoise.erl index e984258..1156d97 100644 --- a/src/enoise.erl +++ b/src/enoise.erl @@ -4,15 +4,19 @@ %%% @doc Module is an interface to the Noise protocol %%% [https://noiseprotocol.org] %%% -%%% The module implements Noise over TCP (i.e. `gen_tcp') and after "upgrading" -%%% a `gen_tcp'-socket into a `enoise'-socket it has a similar API as -%%% `gen_tcp'. +%%% The module implements Noise handshake in `handshake/3'. %%% -%%% @end -%%% ------------------------------------------------------------------ +%%% For convenience there is also an API to use Noise over TCP (i.e. `gen_tcp') +%%% and after "upgrading" a `gen_tcp'-socket into a `enoise'-socket it has a +%%% similar API as `gen_tcp'. +%%% +%%% @end ------------------------------------------------------------------ -module(enoise). +%% Main function with generic Noise handshake +-export([handshake/3]). + %% API exports - Mainly mimicing gen_tcp -export([ accept/2 , close/1 @@ -56,6 +60,28 @@ binary(). %% API functions %%==================================================================== +%% @doc The main function - performs a Noise handshake +handshake(Options, Role, ComState) -> + Prologue = proplists:get_value(prologue, Options, <<>>), + NoiseProtocol0 = proplists:get_value(noise, Options), + + NoiseProtocol = + case NoiseProtocol0 of + X when is_binary(X); is_list(X) -> + enoise_protocol:from_name(X); + _ -> NoiseProtocol0 + end, + + S = proplists:get_value(s, Options, undefined), + E = proplists:get_value(e, Options, undefined), + RS = proplists:get_value(rs, Options, undefined), + RE = proplists:get_value(re, Options, undefined), + + HSState = enoise_hs_state:init(NoiseProtocol, Role, + Prologue, {S, E, RS, RE}), + do_handshake(HSState, ComState). + + %% @doc Upgrades a gen_tcp, or equivalent, connected socket to a Noise socket, %% that is, performs the client-side noise handshake. %% @@ -65,7 +91,7 @@ binary(). Options :: noise_options()) -> {ok, noise_socket()} | {error, term()}. connect(TcpSock, Options) -> - start_handshake(TcpSock, initiator, Options). + tcp_handshake(TcpSock, initiator, Options). %% @doc Upgrades a gen_tcp, or equivalent, connected socket to a Noise socket, %% that is, performs the server-side noise handshake. @@ -76,7 +102,7 @@ connect(TcpSock, Options) -> Options :: noise_options()) -> {ok, noise_socket()} | {error, term()}. accept(TcpSock, Options) -> - start_handshake(TcpSock, responder, Options). + tcp_handshake(TcpSock, responder, Options). %% @doc Writes `Data' to `Socket' %% @end @@ -125,69 +151,94 @@ controlling_process(#enoise{ pid = Pid }, NewPid) -> %%==================================================================== %% Internal functions %%==================================================================== -start_handshake(TcpSock, Role, Options) -> - case check_tcp(TcpSock) of - {ok, WasActive} -> - inet:setopts(TcpSock, [{active, false}]), %% False for handshake - Prologue = proplists:get_value(prologue, Options, <<>>), - NoiseProtocol = proplists:get_value(noise, Options), +tcp_handshake(TcpSock, Role, Options) -> + case check_gen_tcp(TcpSock) of + ok -> + {ok, [{active, Active}]} = inet:getopts(TcpSock, [active]), + ComState = #{ recv_msg => fun gen_tcp_rcv_msg/1, + send_msg => fun gen_tcp_snd_msg/2, + state => {TcpSock, Active, <<>>} }, - S = proplists:get_value(s, Options, undefined), - E = proplists:get_value(e, Options, undefined), - RS = proplists:get_value(rs, Options, undefined), - RE = proplists:get_value(re, Options, undefined), - - HSState = enoise_hs_state:init(NoiseProtocol, Role, - Prologue, {S, E, RS, RE}), - - do_handshake(TcpSock, HSState, WasActive); + case handshake(Options, Role, ComState) of + {ok, #{ rx := Rx, tx := Tx }, #{ state := {_, _, Buf} }} -> + {ok, Pid} = enoise_connection:start_link(TcpSock, Rx, Tx, self(), {Active, Buf}), + {ok, #enoise{ pid = Pid }}; + Err = {error, _} -> + Err + end; Err = {error, _} -> Err end. -do_handshake(TcpSock, HState, WasActive) -> +do_handshake(HState, ComState) -> case enoise_hs_state:next_message(HState) of in -> - case hs_recv(TcpSock) of - {ok, Data} -> + case hs_recv_msg(ComState) of + {ok, Data, ComState1} -> {ok, HState1, _Msg} = enoise_hs_state:read_message(HState, Data), - do_handshake(TcpSock, HState1, WasActive); + do_handshake(HState1, ComState1); Err = {error, _} -> Err end; out -> {ok, HState1, Msg} = enoise_hs_state:write_message(HState, <<>>), - hs_send(TcpSock, Msg), - do_handshake(TcpSock, HState1, WasActive); + {ok, ComState1} = hs_send_msg(ComState, Msg), + do_handshake(HState1, ComState1); done -> - {ok, #{ rx := Rx, tx := Tx }} = enoise_hs_state:finalize(HState), - {ok, Pid} = enoise_connection:start_link(TcpSock, Rx, Tx, self(), WasActive), - {ok, #enoise{ pid = Pid }} + {ok, Res} = enoise_hs_state:finalize(HState), + {ok, Res, ComState} end. -check_tcp(TcpSock) -> - {ok, TcpOpts} = inet:getopts(TcpSock, [mode, packet, active, header, packet_size]), +hs_recv_msg(CS = #{ recv_msg := Recv, state := S }) -> + case Recv(S) of + {ok, Data, S1} -> {ok, Data, CS#{ state := S1 }}; + Err = {error, _} -> Err + end. + +hs_send_msg(CS = #{ send_msg := Send, state := S }, Data) -> + case Send(S, Data) of + {ok, S1} -> {ok, CS#{ state := S1 }}; + Err = {error, _} -> Err + end. + + +%% -- gen_tcp specific functions --------------------------------------------- + +check_gen_tcp(TcpSock) -> + {ok, TcpOpts} = inet:getopts(TcpSock, [mode, packet, header, packet_size]), Packet = proplists:get_value(packet, TcpOpts, 0), Header = proplists:get_value(header, TcpOpts, 0), - Active = proplists:get_value(active, TcpOpts, true), PSize = proplists:get_value(packet_size, TcpOpts, undefined), Mode = proplists:get_value(mode, TcpOpts, binary), case (Packet == 0 orelse Packet == raw) andalso Header == 0 andalso PSize == 0 andalso Mode == binary of true -> - case gen_tcp:controlling_process(TcpSock, self()) of - ok -> {ok, Active}; - Err = {error, _} -> Err - end; + gen_tcp:controlling_process(TcpSock, self()); false -> - {error, {invalid_tcp_options, proplists:delete(active, TcpOpts)}} + {error, {invalid_tcp_options, TcpOpts}} end. -hs_send(TcpSock, Msg) -> +gen_tcp_snd_msg(S = {TcpSock, _, _}, Msg) -> Len = byte_size(Msg), - gen_tcp:send(TcpSock, <>). + ok = gen_tcp:send(TcpSock, <>), + {ok, S}. -hs_recv(TcpSock) -> +gen_tcp_rcv_msg({TcpSock, true, Buf}) -> + receive {tcp, TcpSock, Data} -> + case <> of + Buf1 = <> when byte_size(Rest) < Len -> + gen_tcp_rcv_msg({TcpSock, true, Buf1}); + <> -> + <> = Rest, + {ok, Data1, {TcpSock, true, Buf1}} + end + after 1000 -> + {error, timeout} + end; +gen_tcp_rcv_msg(S = {TcpSock, false, <<>>}) -> {ok, <>} = gen_tcp:recv(TcpSock, 2, 1000), - gen_tcp:recv(TcpSock, Len, 1000). + case gen_tcp:recv(TcpSock, Len, 1000) of + {ok, Data} -> {ok, Data, S}; + Err = {error, _} -> Err + end. diff --git a/src/enoise_connection.erl b/src/enoise_connection.erl index 8062592..b61390b 100644 --- a/src/enoise_connection.erl +++ b/src/enoise_connection.erl @@ -2,7 +2,10 @@ %%% @copyright 2018, Aeternity Anstalt %%% %%% @doc Module implementing a gen_server for holding a handshaked -%%% Noise connection. +%%% Noise connection over gen_tcp. +%%% +%%% Some care is needed since the underlying transmission is broken up +%%% into Noise packets, so we need some buffering. %%% %%% @end %%% ------------------------------------------------------------------ @@ -25,13 +28,20 @@ -record(state, {rx, tx, owner, tcp_sock, active, buf = <<>>, rawbuf = <<>>}). %% -- API -------------------------------------------------------------------- -start_link(TcpSock, Rx, Tx, Owner, Active) -> - inet:setopts(TcpSock, [{active, Active}]), - State = #state{ rx = Rx, tx = Tx, owner = Owner, - tcp_sock = TcpSock, active = Active }, +start_link(TcpSock, Rx, Tx, Owner, {Active, Buf}) -> + State0 = #state{ rx = Rx, tx = Tx, owner = Owner, + tcp_sock = TcpSock, active = Active }, + State = case Active of + true -> State0; + false -> State0#state{ rawbuf = Buf } + end, case gen_server:start_link(?MODULE, [State], []) of {ok, Pid} -> ok = gen_tcp:controlling_process(TcpSock, Pid), + %% Changing controlling process if active requires a bit + %% of fiddling with already received content... + [ Pid ! {tcp, TcpSock, Buf} || Buf /= <<>>, Active ], + flush_tcp(Active, Pid, TcpSock), {ok, Pid}; Err = {error, _} -> Err @@ -55,6 +65,10 @@ controlling_process(Noise, NewPid) -> init([State]) -> {ok, State}. +handle_call(close, _From, S) -> + {stop, normal, ok, S}; +handle_call(_Call, _From, S = #state{ tcp_sock = closed }) -> + {reply, {error, closed}, S}; handle_call({send, Data}, _From, S) -> {Res, S1} = handle_send(S, Data), {reply, Res, S1}; @@ -65,9 +79,7 @@ handle_call({recv, Length, Timeout}, _From, S) -> {reply, Res, S1}; handle_call({controlling_process, OldPid, NewPid}, _From, S) -> {Res, S1} = handle_control_change(S, OldPid, NewPid), - {reply, Res, S1}; -handle_call(close, _From, S) -> - {stop, normal, ok, S}. + {reply, Res, S1}. handle_cast(_Msg, S) -> {noreply, S}. @@ -78,13 +90,13 @@ handle_info({tcp, TS, Data}, S = #state{ tcp_sock = TS }) -> {noreply, S2}; handle_info({tcp_closed, TS}, S = #state{ tcp_sock = TS, active = A, owner = O }) -> [ O ! {tcp_closed, TS} || A ], - {stop, normal, S#state{ tcp_sock = undefined }}; + {noreply, S#state{ tcp_sock = closed }}; handle_info(Msg, S) -> io:format("Unexpected info: ~p\n", [Msg]), {noreply, S}. terminate(_Reason, #state{ tcp_sock = TcpSock }) -> - [ gen_tcp:close(TcpSock) || TcpSock /= undefined ], + [ gen_tcp:close(TcpSock) || TcpSock /= closed ], ok. code_change(_OldVsn, State, _Extra) -> @@ -102,7 +114,7 @@ handle_control_change(S, _OldPid, _NewPid) -> handle_data(S = #state{ rawbuf = Buf, rx = Rx }, Data) -> case <> of - B = <> when Len < byte_size(Rest) -> + B = <> when Len > byte_size(Rest) -> {S#state{ rawbuf = B }, []}; %% Not a full message - save it <> -> <> = Rest, @@ -211,3 +223,12 @@ timed_recv(TcpSock, Len, TO) -> Err = {error, _} -> Err end. + +flush_tcp(false, _Pid, _TcpSock) -> + ok; +flush_tcp(true, Pid, TcpSock) -> + receive {tcp, TcpSock, Data} -> + Pid ! {tcp, TcpSock, Data}, + flush_tcp(true, Pid, TcpSock) + after 1 -> ok + end.