From bc8ebc7ec63e9363931efccffc858861f3a6e060 Mon Sep 17 00:00:00 2001 From: Hans Svensson Date: Fri, 2 Mar 2018 15:42:43 +0100 Subject: [PATCH] Refactor handshake flow control --- src/enoise.erl | 75 ++++++++++++----------------------------- src/enoise_hs_state.erl | 25 ++++++++------ 2 files changed, 36 insertions(+), 64 deletions(-) diff --git a/src/enoise.erl b/src/enoise.erl index 52ed772..a1c6e67 100644 --- a/src/enoise.erl +++ b/src/enoise.erl @@ -32,35 +32,23 @@ %%==================================================================== %% API functions %%==================================================================== -connect(Address, Port, Options) -> - connect(Address, Port, Options, infinity). +connect(TcpSock, Options) -> + do_handshake(TcpSock, initiator, Options). +accept(TcpSock, Options) -> + do_handshake(TcpSock, responder, Options). -connect(Address, Port, Options, Timeout) -> - case initiate_handshake(initiator, Options) of - {ok, HS} -> - TcpOpts = enoise_opts:tcp_opts(Options), - case gen_tcp:connect(Address, Port, TcpOpts, Timeout) of - {ok, TcpSock} -> - do_handshake(TcpSock, HS, Options); - Err = {error, _Reason} -> - Err - end; - Err = {error, _Reason} -> - Err - end. - -send(E = #enoise{ tcp_sock = TcpSock, rx = RX0 }, Msg0) -> - {ok, RX1, Msg1} = enoise_cipher_state:encrypt_with_ad(RX0, <<>>, Msg0), +send(E = #enoise{ tcp_sock = TcpSock, tx = TX0 }, Msg0) -> + {ok, TX1, Msg1} = enoise_cipher_state:encrypt_with_ad(TX0, <<>>, Msg0), gen_tcp:send(TcpSock, <<(byte_size(Msg1)):16, Msg1/binary>>), - E#enoise{ rx = RX1 }. + E#enoise{ tx = TX1 }. -recv(E = #enoise{ tcp_sock = TcpSock, tx = TX0 }) -> +recv(E = #enoise{ tcp_sock = TcpSock, rx = RX0 }) -> receive {tcp, TcpSock, <>} -> Size = byte_size(Data), - {ok, TX1, Msg1} = enoise_cipher_state:decrypt_with_ad(TX0, <<>>, Data), - {E#enoise{ tx = TX1 }, Msg1} - after 1000 -> error(timeout) end. + {ok, RX1, Msg1} = enoise_cipher_state:decrypt_with_ad(RX0, <<>>, Data), + {E#enoise{ rx = RX1 }, Msg1} + after 5000 -> error(timeout) end. close(#enoise{ tcp_sock = TcpSock }) -> gen_tcp:close(TcpSock). @@ -69,7 +57,7 @@ close(#enoise{ tcp_sock = TcpSock }) -> %%==================================================================== %% Internal functions %%==================================================================== -initiate_handshake(Role, Options) -> +do_handshake(TcpSock, Role, Options) -> Prologue = proplists:get_value(prologue, Options, <<>>), NoiseProtocol = proplists:get_value(noise, Options), @@ -79,41 +67,22 @@ initiate_handshake(Role, Options) -> RE = proplists:get_value(re, Options, undefined), HSState = enoise_hs_state:init(NoiseProtocol, Role, Prologue, {S, E, RS, RE}), - {ok, HSState}. - - -do_handshake(TcpSock, HState, Options) -> - PreComm = proplists:get_value(pre_comm, Options, <<>>), %% TODO: Not standard! - - gen_tcp:send(TcpSock, PreComm), - - do_handshake(TcpSock, HState). - + do_handshake(TcpSock, HSState). do_handshake(TcpSock, HState) -> case enoise_hs_state:next_message(HState) of in -> receive {tcp, TcpSock, Data} -> - case enoise_hs_state:read_message(HState, Data) of - {ok, HState1, _Msg} -> - do_handshake(TcpSock, HState1); - {done, _HState1, _Msg, {C1, C2}} -> - {ok, #enoise{ tcp_sock = TcpSock, rx = C1, tx = C2 }} - end - after 1000 -> - error(timeout) - end; + {ok, HState1, _Msg} = enoise_hs_state:read_message(HState, Data), + do_handshake(TcpSock, HState1) + after 1000 -> error(timeout) end; out -> - case enoise_hs_state:write_message(HState, <<>>) of - {ok, HState1, Msg} -> - io:format("Sending: ~p\n", [add_len(Msg)]), - gen_tcp:send(TcpSock, add_len(Msg)), - do_handshake(TcpSock, HState1); - {done, _HState1, Msg, {C1, C2}} -> - io:format("Sending: ~p\n", [add_len(Msg)]), - gen_tcp:send(TcpSock, add_len(Msg)), - {ok, #enoise{ tcp_sock = TcpSock, rx = C1, tx = C2 }} - end + {ok, HState1, Msg} = enoise_hs_state:write_message(HState, <<>>), + gen_tcp:send(TcpSock, add_len(Msg)), + do_handshake(TcpSock, HState1); + done -> + {ok, #{ rx := Rx, tx := Tx }} = enoise_hs_state:finalize(HState), + {ok, #enoise{ tcp_sock = TcpSock, rx = Rx, tx = Tx }} end. add_len(Msg) -> diff --git a/src/enoise_hs_state.erl b/src/enoise_hs_state.erl index f9a96a1..bc369b2 100644 --- a/src/enoise_hs_state.erl +++ b/src/enoise_hs_state.erl @@ -4,7 +4,7 @@ -module(enoise_hs_state). --export([init/4, next_message/1, read_message/2, write_message/2]). +-export([finalize/1, init/4, next_message/1, read_message/2, write_message/2]). -include("enoise.hrl"). @@ -37,6 +37,15 @@ init(Protocol, Role, Prologue, {S, E, RS, RE}) -> ({in, [e]}, HS0) -> mix_hash(HS0, RE) end, HS, PreMsgs). +finalize(#noise_hs{ msgs = [], ss = SS, role = Role }) -> + {C1, C2} = enoise_sym_state:split(SS), + case Role of + initiator -> {ok, #{ tx => C1, rx => C2 }}; + responder -> {ok, #{ rx => C1, tx => C2 }} + end; +finalize(_) -> + error({bad_state, finalize}). + next_message(#noise_hs{ msgs = [{Dir, _} | _] }) -> Dir; next_message(_) -> done. @@ -44,19 +53,12 @@ write_message(HS = #noise_hs{ msgs = [{out, Msg} | Msgs] }, PayLoad) -> {HS1, MsgBuf1} = write_message(HS#noise_hs{ msgs = Msgs }, Msg, <<>>), {ok, HS2, MsgBuf2} = encrypt_and_hash(HS1, PayLoad), MsgBuf = <>, - case Msgs of - [] -> {done, HS2, MsgBuf, enoise_sym_state:split(HS2#noise_hs.ss)}; - _ -> {ok, HS2, MsgBuf} - end. + {ok, HS2, MsgBuf}. read_message(HS = #noise_hs{ msgs = [{in, Msg} | Msgs] }, <>) -> Size = byte_size(Message), {HS1, RestBuf1} = read_message(HS#noise_hs{ msgs = Msgs }, Msg, Message), - {ok, HS2, PlainBuf} = decrypt_and_hash(HS1, RestBuf1), - case Msgs of - [] -> {done, HS2, PlainBuf, enoise_sym_state:split(HS2#noise_hs.ss)}; - _ -> {ok, HS2, PlainBuf} - end. + decrypt_and_hash(HS1, RestBuf1). write_message(HS, [], MsgBuf) -> {HS, MsgBuf}; @@ -131,6 +133,7 @@ decrypt_and_hash(HS = #noise_hs{ ss = SS0 }, CipherText) -> {ok, SS1, PlainText} = enoise_sym_state:decrypt_and_hash(SS0, CipherText), {ok, HS#noise_hs{ ss = SS1 }, PlainText}. + msgs(Role, Protocol) -> {_Pre, Msgs} = protocol(Protocol), role_adapt(Role, Msgs). @@ -142,7 +145,7 @@ pre_msgs(Role, Protocol) -> role_adapt(initiator, Msgs) -> Msgs; role_adapt(responder, Msgs) -> - Flip = fun(in) -> out; (out) -> in end, + Flip = fun({in, Msg}) -> {out, Msg}; ({out, Msg}) -> {in, Msg} end, lists:map(Flip, Msgs). protocol(nn) ->