diff --git a/src/enoise.erl b/src/enoise.erl index 1156d97..2c494ed 100644 --- a/src/enoise.erl +++ b/src/enoise.erl @@ -15,7 +15,7 @@ -module(enoise). %% Main function with generic Noise handshake --export([handshake/3]). +-export([handshake/2, handshake/3, step_handshake/2]). %% API exports - Mainly mimicing gen_tcp -export([ accept/2 @@ -60,26 +60,18 @@ binary(). %% API functions %%==================================================================== +%% @doc Start an interactive handshake +handshake(Options, Role) -> + HState = create_hstate(Options, Role), + step_handshake(HState, <<>>). + +step_handshake(HState, Data) -> + do_step_handshake(HState, Data). + %% @doc The main function - performs a Noise handshake handshake(Options, Role, ComState) -> - Prologue = proplists:get_value(prologue, Options, <<>>), - NoiseProtocol0 = proplists:get_value(noise, Options), - - NoiseProtocol = - case NoiseProtocol0 of - X when is_binary(X); is_list(X) -> - enoise_protocol:from_name(X); - _ -> NoiseProtocol0 - end, - - S = proplists:get_value(s, Options, undefined), - E = proplists:get_value(e, Options, undefined), - RS = proplists:get_value(rs, Options, undefined), - RE = proplists:get_value(re, Options, undefined), - - HSState = enoise_hs_state:init(NoiseProtocol, Role, - Prologue, {S, E, RS, RE}), - do_handshake(HSState, ComState). + HState = create_hstate(Options, Role), + do_handshake(HState, ComState). %% @doc Upgrades a gen_tcp, or equivalent, connected socket to a Noise socket, @@ -151,25 +143,6 @@ controlling_process(#enoise{ pid = Pid }, NewPid) -> %%==================================================================== %% Internal functions %%==================================================================== -tcp_handshake(TcpSock, Role, Options) -> - case check_gen_tcp(TcpSock) of - ok -> - {ok, [{active, Active}]} = inet:getopts(TcpSock, [active]), - ComState = #{ recv_msg => fun gen_tcp_rcv_msg/1, - send_msg => fun gen_tcp_snd_msg/2, - state => {TcpSock, Active, <<>>} }, - - case handshake(Options, Role, ComState) of - {ok, #{ rx := Rx, tx := Tx }, #{ state := {_, _, Buf} }} -> - {ok, Pid} = enoise_connection:start_link(TcpSock, Rx, Tx, self(), {Active, Buf}), - {ok, #enoise{ pid = Pid }}; - Err = {error, _} -> - Err - end; - Err = {error, _} -> - Err - end. - do_handshake(HState, ComState) -> case enoise_hs_state:next_message(HState) of in -> @@ -201,8 +174,58 @@ hs_send_msg(CS = #{ send_msg := Send, state := S }, Data) -> Err = {error, _} -> Err end. +do_step_handshake(HState, Data) -> + case enoise_hs_state:next_message(HState) of + in when Data == <<>> -> + {in, HState}; + in -> + {ok, HState1, _Msg} = enoise_hs_state:read_message(HState, Data), %% TODO: error handling + do_step_handshake(HState1, <<>>); + out -> + {ok, HState1, Msg} = enoise_hs_state:write_message(HState, <<>>), + {out, Msg, HState1}; + done -> + {done, enoise_hs_state:finalize(HState)} + end. %% -- gen_tcp specific functions --------------------------------------------- +tcp_handshake(TcpSock, Role, Options) -> + case check_gen_tcp(TcpSock) of + ok -> + {ok, [{active, Active}]} = inet:getopts(TcpSock, [active]), + ComState = #{ recv_msg => fun gen_tcp_rcv_msg/1, + send_msg => fun gen_tcp_snd_msg/2, + state => {TcpSock, Active, <<>>} }, + + case handshake(Options, Role, ComState) of + {ok, #{ rx := Rx, tx := Tx }, #{ state := {_, _, Buf} }} -> + {ok, Pid} = enoise_connection:start_link(TcpSock, Rx, Tx, self(), {Active, Buf}), + {ok, #enoise{ pid = Pid }}; + Err = {error, _} -> + Err + end; + Err = {error, _} -> + Err + end. + +create_hstate(Options, Role) -> + Prologue = proplists:get_value(prologue, Options, <<>>), + NoiseProtocol0 = proplists:get_value(noise, Options), + + NoiseProtocol = + case NoiseProtocol0 of + X when is_binary(X); is_list(X) -> + enoise_protocol:from_name(X); + _ -> NoiseProtocol0 + end, + + S = proplists:get_value(s, Options, undefined), + E = proplists:get_value(e, Options, undefined), + RS = proplists:get_value(rs, Options, undefined), + RE = proplists:get_value(re, Options, undefined), + + enoise_hs_state:init(NoiseProtocol, Role, + Prologue, {S, E, RS, RE}). check_gen_tcp(TcpSock) -> {ok, TcpOpts} = inet:getopts(TcpSock, [mode, packet, header, packet_size]),