diff --git a/apps/emqx_stomp/src/emqx_stomp_connection.erl b/apps/emqx_stomp/src/emqx_stomp_connection.erl index b6c05ac68..dc1977944 100644 --- a/apps/emqx_stomp/src/emqx_stomp_connection.erl +++ b/apps/emqx_stomp/src/emqx_stomp_connection.erl @@ -33,8 +33,11 @@ , terminate/2 ]). +%% for protocol +-export([send/4, heartbeat/2]). + -record(stomp_client, {transport, socket, peername, conn_name, conn_state, - await_recv, rate_limit, parse_fun, proto_state, + await_recv, rate_limit, parser, proto_state, proto_env, heartbeat}). -define(INFO_KEYS, [peername, await_recv, conn_state]). @@ -55,9 +58,12 @@ init([Transport, Sock, ProtoEnv]) -> {ok, NewSock} -> {ok, Peername} = Transport:ensure_ok_or_exit(peername, [NewSock]), ConnName = esockd:format(Peername), - SendFun = send_fun(Transport, Sock), - ParseFun = emqx_stomp_frame:parser(ProtoEnv), - ProtoState = emqx_stomp_protocol:init(Peername, SendFun, ProtoEnv), + SendFun = {fun ?MODULE:send/4, [Transport, Sock, self()]}, + HrtBtFun = {fun ?MODULE:heartbeat/2, [Transport, Sock]}, + Parser = emqx_stomp_frame:init_parer_state(ProtoEnv), + ProtoState = emqx_stomp_protocol:init(#{peername => Peername, + sendfun => SendFun, + heartfun => HrtBtFun}, ProtoEnv), RateLimit = init_rate_limit(proplists:get_value(rate_limit, ProtoEnv)), State = run_socket(#stomp_client{transport = Transport, socket = NewSock, @@ -66,7 +72,7 @@ init([Transport, Sock, ProtoEnv]) -> conn_state = running, await_recv = false, rate_limit = RateLimit, - parse_fun = ParseFun, + parser = Parser, proto_env = ProtoEnv, proto_state = ProtoState}), gen_server:enter_loop(?MODULE, [{hibernate_after, 5000}], State, 20000); @@ -79,17 +85,17 @@ init_rate_limit(undefined) -> init_rate_limit({Rate, Burst}) -> esockd_rate_limit:new(Rate, Burst). -send_fun(Transport, Sock) -> - Self = self(), - fun(Data) -> - try Transport:async_send(Sock, Data) of - ok -> ok; - {error, Reason} -> Self ! {shutdown, Reason} - catch - error:Error -> Self ! {shutdown, Error} - end +send(Data, Transport, Sock, ConnPid) -> + try Transport:async_send(Sock, Data) of + ok -> ok; + {error, Reason} -> ConnPid ! {shutdown, Reason} + catch + error:Error -> ConnPid ! {shutdown, Error} end. +heartbeat(Transport, Sock) -> + Transport:send(Sock, <<$\n>>). + handle_call(info, _From, State = #stomp_client{transport = Transport, socket = Sock, peername = Peername, @@ -124,16 +130,19 @@ handle_info({transaction, {timeout, Id}}, State) -> emqx_stomp_transaction:timeout(Id), noreply(State); -handle_info({heartbeat, start, {Cx, Cy}}, - State = #stomp_client{transport = Transport, socket = Sock}) -> - Self = self(), - Incomming = {Cx, statfun(recv_oct, State), fun() -> Self ! {heartbeat, timeout} end}, - Outgoing = {Cy, statfun(send_oct, State), fun() -> Transport:send(Sock, <<$\n>>) end}, - {ok, HbProc} = emqx_stomp_heartbeat:start_link(Incomming, Outgoing), - noreply(State#stomp_client{heartbeat = HbProc}); +handle_info({timeout, TRef, TMsg}, State) when TMsg =:= incoming; + TMsg =:= outgoing -> -handle_info({heartbeat, timeout}, State) -> - stop({shutdown, heartbeat_timeout}, State); + Stat = case TMsg of + incoming -> recv_oct; + _ -> send_oct + end, + case getstat(Stat, State) of + {ok, Val} -> + with_proto(timeout, [TRef, {TMsg, Val}], State); + {error, Reason} -> + shutdown({sock_error, Reason}, State) + end; handle_info({'EXIT', HbProc, Error}, State = #stomp_client{heartbeat = HbProc}) -> stop(Error, State); @@ -186,14 +195,24 @@ code_change(_OldVsn, State, _Extra) -> %% Receive and Parse data %%-------------------------------------------------------------------- +with_proto(Fun, Args, State = #stomp_client{proto_state = ProtoState}) -> + case erlang:apply(emqx_stomp_protocol, Fun, Args ++ [ProtoState]) of + {ok, NProtoState} -> + noreply(State#stomp_client{proto_state = NProtoState}); + {F, Reason, NProtoState} when F == stop; + F == error; + F == shutdown -> + shutdown(Reason, State#stomp_client{proto_state = NProtoState}) + end. + received(<<>>, State) -> noreply(State); -received(Bytes, State = #stomp_client{parse_fun = ParseFun, +received(Bytes, State = #stomp_client{parser = Parser, proto_state = ProtoState}) -> - try ParseFun(Bytes) of - {more, NewParseFun} -> - noreply(State#stomp_client{parse_fun = NewParseFun}); + try emqx_stomp_frame:parse(Bytes, Parser) of + {more, NewParser} -> + noreply(State#stomp_client{parser = NewParser}); {ok, Frame, Rest} -> ?LOG(info, "RECV Frame: ~s", [emqx_stomp_frame:format(Frame)], State), case emqx_stomp_protocol:received(Frame, ProtoState) of @@ -216,7 +235,7 @@ received(Bytes, State = #stomp_client{parse_fun = ParseFun, end. reset_parser(State = #stomp_client{proto_env = ProtoEnv}) -> - State#stomp_client{parse_fun = emqx_stomp_frame:parser(ProtoEnv)}. + State#stomp_client{parser = emqx_stomp_frame:init_parer_state(ProtoEnv)}. rate_limit(_Size, State = #stomp_client{rate_limit = undefined}) -> run_socket(State); @@ -238,12 +257,10 @@ run_socket(State = #stomp_client{transport = Transport, socket = Sock}) -> Transport:async_recv(Sock, 0, infinity), State#stomp_client{await_recv = true}. -statfun(Stat, #stomp_client{transport = Transport, socket = Sock}) -> - fun() -> - case Transport:getstat(Sock, [Stat]) of - {ok, [{Stat, Val}]} -> {ok, Val}; - {error, Error} -> {error, Error} - end +getstat(Stat, #stomp_client{transport = Transport, socket = Sock}) -> + case Transport:getstat(Sock, [Stat]) of + {ok, [{Stat, Val}]} -> {ok, Val}; + {error, Error} -> {error, Error} end. noreply(State) -> diff --git a/apps/emqx_stomp/src/emqx_stomp_frame.erl b/apps/emqx_stomp/src/emqx_stomp_frame.erl index e5795335a..5d5ddffe1 100644 --- a/apps/emqx_stomp/src/emqx_stomp_frame.erl +++ b/apps/emqx_stomp/src/emqx_stomp_frame.erl @@ -70,7 +70,8 @@ -include("emqx_stomp.hrl"). --export([ parser/1 +-export([ init_parer_state/1 + , parse/2 , serialize/1 ]). @@ -91,14 +92,18 @@ -record(frame_limit, {max_header_num, max_header_length, max_body_length}). --type(parser() :: fun((binary()) -> {ok, stomp_frame(), binary()} - | {more, parser()} - | {error, any()})). +-type(result() :: {ok, stomp_frame(), binary()} + | {more, parser()} + | {error, any()}). + +-type(parser() :: #{phase := none | command | headers | hdname | hdvalue | body, + pre => binary(), + state := #parser_state{}}). %% @doc Initialize a parser --spec parser([proplists:property()]) -> parser(). -parser(Opts) -> - fun(Bin) -> parse(none, Bin, #parser_state{limit = limit(Opts)}) end. +-spec init_parer_state([proplists:property()]) -> parser(). +init_parer_state(Opts) -> + #{phase => none, state => #parser_state{limit = limit(Opts)}}. limit(Opts) -> #frame_limit{max_header_num = g(max_header_num, Opts, ?MAX_HEADER_NUM), @@ -108,29 +113,31 @@ limit(Opts) -> g(Key, Opts, Val) -> proplists:get_value(Key, Opts, Val). -%% @doc Parse frame --spec(parse(Phase :: atom(), binary(), #parser_state{}) -> - {ok, stomp_frame(), binary()} | {more, parser()} | {error, any()}). -parse(none, <<>>, State) -> - {more, fun(Bin) -> parse(none, Bin, State) end}; -parse(none, <>, State) -> - parse(none, Bin, State); -parse(none, Bin, State) -> - parse(command, Bin, State); +-spec parse(binary(), parser()) -> result(). +parse(<<>>, Parser) -> + {more, Parser}; -parse(Phase, <<>>, State) -> - {more, fun(Bin) -> parse(Phase, Bin, State) end}; -parse(Phase, <>, State) -> - {more, fun(Bin) -> parse(Phase, <>, State) end}; -parse(Phase, <>, State) -> +parse(Bytes, #{phase := body, len := Len, state := State}) -> + parse(body, Bytes, State, Len); + +parse(Bytes, Parser = #{pre := Pre}) -> + parse(<
>, maps:without([pre], Parser));
+parse(<>, #{phase := Phase, state := State}) ->
parse(Phase, <>, State);
-parse(_Phase, <>, _State) ->
+parse(<>, Parser) ->
+ {more, Parser#{pre => <>}};
+parse(<>, _Parser) ->
{error, linefeed_expected};
-parse(Phase, <>, State) when Phase =:= hdname; Phase =:= hdvalue ->
- {more, fun(Bin) -> parse(Phase, <>, State) end};
-parse(Phase, <>, State) when Phase =:= hdname; Phase =:= hdvalue ->
+
+parse(<>, Parser = #{phase := Phase}) when Phase =:= hdname; Phase =:= hdvalue ->
+ {more, Parser#{pre => <>}};
+parse(<>, #{phase := Phase, state := State}) when Phase =:= hdname; Phase =:= hdvalue ->
parse(Phase, Rest, acc(unescape(Ch), State));
+parse(Bytes, #{phase := none, state := State}) ->
+ parse(command, Bytes, State).
+
+%% @private
parse(command, <>, State = #parser_state{acc = Acc}) ->
parse(headers, Rest, State#parser_state{cmd = Acc, acc = <<>>});
parse(command, <>, State) ->
@@ -153,20 +160,21 @@ parse(hdvalue, <>, State = #parser_state{headers = Headers, hd
parse(hdvalue, <>, State) ->
parse(hdvalue, Rest, acc(Ch, State)).
+%% @private
parse(body, <<>>, State, Length) ->
- {more, fun(Bin) -> parse(body, Bin, State, Length) end};
+ {more, #{phase => body, length => Length, state => State}};
parse(body, Bin, State, none) ->
case binary:split(Bin, <>) of
[Chunk, Rest] ->
{ok, new_frame(acc(Chunk, State)), Rest};
[Chunk] ->
- {more, fun(More) -> parse(body, More, acc(Chunk, State), none) end}
+ {more, #{phase => body, length => none, state => acc(Chunk, State)}}
end;
parse(body, Bin, State, Len) when byte_size(Bin) >= (Len+1) ->
<> = Bin,
{ok, new_frame(acc(Chunk, State)), Rest};
parse(body, Bin, State, Len) ->
- {more, fun(More) -> parse(body, More, acc(Bin, State), Len - byte_size(Bin)) end}.
+ {more, #{phase => body, length => Len - byte_size(Bin), state => acc(Bin, State)}}.
add_header(Name, Value, Headers) ->
case lists:keyfind(Name, 1, Headers) of
diff --git a/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl b/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl
index 22e1c3eb3..79cc8f435 100644
--- a/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl
+++ b/apps/emqx_stomp/src/emqx_stomp_heartbeat.erl
@@ -19,88 +19,74 @@
-include("emqx_stomp.hrl").
--export([ start_link/2
- , stop/1
- ]).
-
-%% callback
-export([ init/1
- , loop/3
+ , check/3
+ , info/1
+ , interval/2
]).
--define(MAX_REPEATS, 1).
+-record(heartbeater, {interval, statval, repeat}).
--record(heartbeater, {name, cycle, tref, val, statfun, action, repeat = 0}).
+-type name() :: incoming | outgoing.
-start_link({0, _, _}, {0, _, _}) ->
- {ok, none};
+-type heartbeat() :: #{incoming => #heartbeater{},
+ outgoing => #heartbeater{}
+ }.
-start_link(Incoming, Outgoing) ->
- Params = [self(), Incoming, Outgoing],
- {ok, spawn_link(?MODULE, init, [Params])}.
-stop(Pid) ->
- Pid ! stop.
+%%--------------------------------------------------------------------
+%% APIs
+%%--------------------------------------------------------------------
-init([Parent, Incoming, Outgoing]) ->
- loop(Parent, heartbeater(incomming, Incoming), heartbeater(outgoing, Outgoing)).
+-spec init({non_neg_integer(), non_neg_integer()}) -> heartbeat().
+init({0, 0}) ->
+ #{};
+init({Cx, Cy}) ->
+ maps:filter(fun(_, V) -> V /= undefined end,
+ #{incoming => heartbeater(Cx),
+ outgoing => heartbeater(Cy)
+ }).
-heartbeater(_, {0, _, _}) ->
+heartbeater(0) ->
undefined;
+heartbeater(I) ->
+ #heartbeater{
+ interval = I,
+ statval = 0,
+ repeat = 0
+ }.
-heartbeater(InOut, {Cycle, StatFun, ActionFun}) ->
- {ok, Val} = StatFun(),
- #heartbeater{name = InOut, cycle = Cycle,
- tref = timer(InOut, Cycle),
- val = Val, statfun = StatFun,
- action = ActionFun}.
-
-loop(Parent, Incomming, Outgoing) ->
- receive
- {heartbeat, incomming} ->
- #heartbeater{val = LastVal, statfun = StatFun,
- action = Action, repeat = Repeat} = Incomming,
- case StatFun() of
- {ok, Val} ->
- if Val =/= LastVal ->
- hibernate([Parent, resume(Incomming, Val), Outgoing]);
- Repeat < ?MAX_REPEATS ->
- hibernate([Parent, resume(Incomming, Val, Repeat+1), Outgoing]);
- true ->
- Action()
- end;
- {error, Error} -> %% einval
- exit({shutdown, Error})
- end;
- {heartbeat, outgoing} ->
- #heartbeater{val = LastVal, statfun = StatFun, action = Action} = Outgoing,
- case StatFun() of
- {ok, Val} ->
- if Val =:= LastVal ->
- Action(), {ok, NewVal} = StatFun(),
- hibernate([Parent, Incomming, resume(Outgoing, NewVal)]);
- true ->
- hibernate([Parent, Incomming, resume(Outgoing, Val)])
- end;
- {error, Error} -> %% einval
- exit({shutdown, Error})
- end;
- stop ->
- ok;
- _Other ->
- loop(Parent, Incomming, Outgoing)
+-spec check(name(), pos_integer(), heartbeat())
+ -> {ok, heartbeat()}
+ | {error, timeout}.
+check(Name, NewVal, HrtBt) ->
+ HrtBter = maps:get(Name, HrtBt),
+ case check(NewVal, HrtBter) of
+ {error, _} = R -> R;
+ {ok, NHrtBter} ->
+ {ok, HrtBt#{Name => NHrtBter}}
end.
-resume(Hb, NewVal) ->
- resume(Hb, NewVal, 0).
-resume(Hb = #heartbeater{name = InOut, cycle = Cycle}, NewVal, Repeat) ->
- Hb#heartbeater{tref = timer(InOut, Cycle), val = NewVal, repeat = Repeat}.
+check(NewVal, HrtBter = #heartbeater{statval = OldVal,
+ repeat = Repeat}) ->
+ if
+ NewVal =/= OldVal ->
+ {ok, HrtBter#heartbeater{statval = NewVal, repeat = 0}};
+ Repeat < 1 ->
+ {ok, HrtBter#heartbeater{repeat = Repeat + 1}};
+ true -> {error, timeout}
+ end.
-timer(_InOut, 0) ->
- undefined;
-timer(InOut, Cycle) ->
- erlang:send_after(Cycle, self(), {heartbeat, InOut}).
-
-hibernate(Args) ->
- erlang:hibernate(?MODULE, loop, Args).
+-spec info(heartbeat()) -> map().
+info(HrtBt) ->
+ maps:map(fun(_, #heartbeater{interval = Intv,
+ statval = Val,
+ repeat = Repeat}) ->
+ #{interval => Intv, statval => Val, repeat => Repeat}
+ end, HrtBt).
+interval(Type, HrtBt) ->
+ case maps:get(Type, HrtBt, undefined) of
+ undefined -> undefined;
+ #heartbeater{interval = Intv} -> Intv
+ end.
diff --git a/apps/emqx_stomp/src/emqx_stomp_protocol.erl b/apps/emqx_stomp/src/emqx_stomp_protocol.erl
index fa75f08e3..4834955a2 100644
--- a/apps/emqx_stomp/src/emqx_stomp_protocol.erl
+++ b/apps/emqx_stomp/src/emqx_stomp_protocol.erl
@@ -18,50 +18,58 @@
-module(emqx_stomp_protocol).
-include("emqx_stomp.hrl").
+
-include_lib("emqx/include/emqx.hrl").
+-include_lib("emqx/include/logger.hrl").
-include_lib("emqx/include/emqx_mqtt.hrl").
+-logger_header("[Stomp-Proto]").
+
-import(proplists, [get_value/2, get_value/3]).
%% API
--export([ init/3
+-export([ init/2
, info/1
]).
-export([ received/2
, send/2
, shutdown/2
+ , timeout/3
]).
--record(stomp_proto, {peername,
- sendfun,
- connected = false,
- proto_ver,
- proto_name,
- heart_beats,
- login,
- allow_anonymous,
- default_user,
- subscriptions = []}).
+-record(stomp_proto, {
+ peername,
+ heartfun,
+ sendfun,
+ connected = false,
+ proto_ver,
+ proto_name,
+ heart_beats,
+ login,
+ allow_anonymous,
+ default_user,
+ subscriptions = [],
+ timers :: #{atom() => disable | undefined | reference()}
+ }).
+
+-define(TIMER_TABLE, #{
+ incoming_timer => incoming,
+ outgoing_timer => outgoing
+ }).
-type(stomp_proto() :: #stomp_proto{}).
--define(LOG(Level, Format, Args, State),
- emqx_logger:Level("Stomp(~s): " ++ Format, [esockd:format(State#stomp_proto.peername) | Args])).
-
--define(record_to_proplist(Def, Rec),
- lists:zip(record_info(fields, Def), tl(tuple_to_list(Rec)))).
-
--define(record_to_proplist(Def, Rec, Fields),
- [{K, V} || {K, V} <- ?record_to_proplist(Def, Rec),
- lists:member(K, Fields)]).
-
%% @doc Init protocol
-init(Peername, SendFun, Env) ->
+init(#{peername := Peername,
+ sendfun := SendFun,
+ heartfun := HeartFun}, Env) ->
AllowAnonymous = get_value(allow_anonymous, Env, false),
DefaultUser = get_value(default_user, Env),
#stomp_proto{peername = Peername,
+ heartfun = HeartFun,
sendfun = SendFun,
+ timers = #{},
allow_anonymous = AllowAnonymous,
default_user = DefaultUser}.
@@ -78,9 +86,10 @@ info(#stomp_proto{connected = Connected,
{login, Login},
{subscriptions, Subscriptions}].
--spec(received(stomp_frame(), stomp_proto()) -> {ok, stomp_proto()}
- | {error, any(), stomp_proto()}
- | {stop, any(), stomp_proto()}).
+-spec(received(stomp_frame(), stomp_proto())
+ -> {ok, stomp_proto()}
+ | {error, any(), stomp_proto()}
+ | {stop, any(), stomp_proto()}).
received(Frame = #stomp_frame{command = <<"STOMP">>}, State) ->
received(Frame#stomp_frame{command = <<"CONNECT">>}, State);
@@ -92,12 +101,11 @@ received(#stomp_frame{command = <<"CONNECT">>, headers = Headers},
Passc = header(<<"passcode">>, Headers),
case check_login(Login, Passc, AllowAnonymous, DefaultUser) of
true ->
- Heartbeats = header(<<"heart-beat">>, Headers, <<"0,0">>),
- self() ! {heartbeat, start, parse_heartbeats(Heartbeats)},
- NewState = State#stomp_proto{connected = true, proto_ver = Version,
- heart_beats = Heartbeats, login = Login},
+ Heartbeats = parse_heartbeats(header(<<"heart-beat">>, Headers, <<"0,0">>)),
+ NState = start_heartbeart_timer(Heartbeats, State#stomp_proto{connected = true,
+ proto_ver = Version, login = Login}),
send(connected_frame([{<<"version">>, Version},
- {<<"heart-beat">>, reverse_heartbeats(Heartbeats)}]), NewState);
+ {<<"heart-beat">>, reverse_heartbeats(Heartbeats)}]), NState);
false ->
send(error_frame(undefined, <<"Login or passcode error!">>), State),
{error, login_or_passcode_error, State}
@@ -206,8 +214,8 @@ received(#stomp_frame{command = <<"BEGIN">>, headers = Headers}, State) ->
received(#stomp_frame{command = <<"COMMIT">>, headers = Headers}, State) ->
Id = header(<<"transaction">>, Headers),
case emqx_stomp_transaction:commit(Id, State) of
- {ok, NewState} ->
- maybe_send_receipt(receipt_id(Headers), NewState);
+ {ok, NState} ->
+ maybe_send_receipt(receipt_id(Headers), NState);
{error, not_found} ->
send(error_frame(receipt_id(Headers), ["Transaction ", Id, " not found"]), State)
end;
@@ -248,17 +256,40 @@ send(Msg = #message{topic = Topic, headers = Headers, payload = Payload},
body = Payload},
send(Frame, State);
false ->
- ?LOG(error, "Stomp dropped: ~p", [Msg], State),
+ ?LOG(error, "Stomp dropped: ~p", [Msg]),
{error, dropped, State}
end;
-send(Frame, State = #stomp_proto{sendfun = SendFun}) ->
- ?LOG(info, "SEND Frame: ~s", [emqx_stomp_frame:format(Frame)], State),
+send(Frame, State = #stomp_proto{sendfun = {Fun, Args}}) ->
+ ?LOG(info, "SEND Frame: ~s", [emqx_stomp_frame:format(Frame)]),
Data = emqx_stomp_frame:serialize(Frame),
- ?LOG(debug, "SEND ~p", [Data], State),
- SendFun(Data),
+ ?LOG(debug, "SEND ~p", [Data]),
+ erlang:apply(Fun, [Data] ++ Args),
{ok, State}.
+shutdown(_Reason, _State) ->
+ ok.
+
+timeout(_TRef, {incoming, NewVal},
+ State = #stomp_proto{heart_beats = HrtBt}) ->
+ case emqx_stomp_heartbeat:check(incoming, NewVal, HrtBt) of
+ {error, timeout} ->
+ {shutdown, heartbeat_timeout, State};
+ {ok, NHrtBt} ->
+ {ok, reset_timer(incoming_timer, State#stomp_proto{heart_beats = NHrtBt})}
+ end;
+
+timeout(_TRef, {outgoing, NewVal},
+ State = #stomp_proto{heart_beats = HrtBt,
+ heartfun = {Fun, Args}}) ->
+ case emqx_stomp_heartbeat:check(outgoing, NewVal, HrtBt) of
+ {error, timeout} ->
+ _ = erlang:apply(Fun, Args),
+ {ok, State};
+ {ok, NHrtBt} ->
+ {ok, reset_timer(outgoing_timer, State#stomp_proto{heart_beats = NHrtBt})}
+ end.
+
negotiate_version(undefined) ->
{ok, <<"1.0">>};
negotiate_version(Accepts) ->
@@ -322,17 +353,6 @@ error_frame(Headers, undefined, Msg) ->
error_frame(Headers, ReceiptId, Msg) ->
emqx_stomp_frame:make(<<"ERROR">>, [{<<"receipt-id">>, ReceiptId} | Headers], Msg).
-parse_heartbeats(Heartbeats) ->
- CxCy = re:split(Heartbeats, <<",">>, [{return, list}]),
- list_to_tuple([list_to_integer(S) || S <- CxCy]).
-
-reverse_heartbeats(Heartbeats) ->
- CxCy = re:split(Heartbeats, <<",">>, [{return, list}]),
- list_to_binary(string:join(lists:reverse(CxCy), ",")).
-
-shutdown(_Reason, _State) ->
- ok.
-
next_msgid() ->
MsgId = case get(msgid) of
undefined -> 1;
@@ -363,3 +383,52 @@ make_mqtt_message(Topic, Headers, Body) ->
receipt_id(Headers) ->
header(<<"receipt">>, Headers).
+%%--------------------------------------------------------------------
+%% Heartbeat
+
+parse_heartbeats(Heartbeats) ->
+ CxCy = re:split(Heartbeats, <<",">>, [{return, list}]),
+ list_to_tuple([list_to_integer(S) || S <- CxCy]).
+
+reverse_heartbeats({Cx, Cy}) ->
+ iolist_to_binary(io_lib:format("~w,~w", [Cy, Cx])).
+
+start_heartbeart_timer(Heartbeats, State) ->
+ ensure_timer(
+ [incoming_timer, outgoing_timer],
+ State#stomp_proto{heart_beats = emqx_stomp_heartbeat:init(Heartbeats)}).
+
+%%--------------------------------------------------------------------
+%% Timer
+
+ensure_timer([Name], State) ->
+ ensure_timer(Name, State);
+ensure_timer([Name | Rest], State) ->
+ ensure_timer(Rest, ensure_timer(Name, State));
+
+ensure_timer(Name, State = #stomp_proto{timers = Timers}) ->
+ TRef = maps:get(Name, Timers, undefined),
+ Time = interval(Name, State),
+ case TRef == undefined andalso is_integer(Time) andalso Time > 0 of
+ true -> ensure_timer(Name, Time, State);
+ false -> State %% Timer disabled or exists
+ end.
+
+ensure_timer(Name, Time, State = #stomp_proto{timers = Timers}) ->
+ Msg = maps:get(Name, ?TIMER_TABLE),
+ TRef = emqx_misc:start_timer(Time, Msg),
+ State#stomp_proto{timers = Timers#{Name => TRef}}.
+
+reset_timer(Name, State) ->
+ ensure_timer(Name, clean_timer(Name, State)).
+
+reset_timer(Name, Time, State) ->
+ ensure_timer(Name, Time, clean_timer(Name, State)).
+
+clean_timer(Name, State = #stomp_proto{timers = Timers}) ->
+ State#stomp_proto{timers = maps:remove(Name, Timers)}.
+
+interval(incoming_timer, #stomp_proto{heart_beats = HrtBt}) ->
+ emqx_stomp_heartbeat:interval(incoming, HrtBt);
+interval(outgoing_timer, #stomp_proto{heart_beats = HrtBt}) ->
+ emqx_stomp_heartbeat:interval(outgoing, HrtBt).
diff --git a/apps/emqx_stomp/test/emqx_stomp_SUITE.erl b/apps/emqx_stomp/test/emqx_stomp_SUITE.erl
index d8b5cc5b2..ca46762ed 100644
--- a/apps/emqx_stomp/test/emqx_stomp_SUITE.erl
+++ b/apps/emqx_stomp/test/emqx_stomp_SUITE.erl
@@ -100,7 +100,7 @@ t_heartbeat(_) ->
{<<"host">>, <<"127.0.0.1:61613">>},
{<<"login">>, <<"guest">>},
{<<"passcode">>, <<"guest">>},
- {<<"heart-beat">>, <<"500,800">>}])),
+ {<<"heart-beat">>, <<"1000,800">>}])),
{ok, Data} = gen_tcp:recv(Sock, 0),
{ok, #stomp_frame{command = <<"CONNECTED">>,
headers = _,
@@ -345,5 +345,5 @@ parse(Data) ->
ProtoEnv = [{max_headers, 10},
{max_header_length, 1024},
{max_body_length, 8192}],
- ParseFun = emqx_stomp_frame:parser(ProtoEnv),
- ParseFun(Data).
+ Parser = emqx_stomp_frame:init_parer_state(ProtoEnv),
+ emqx_stomp_frame:parse(Data, Parser).
diff --git a/apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl b/apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl
new file mode 100644
index 000000000..0d01bfcd4
--- /dev/null
+++ b/apps/emqx_stomp/test/emqx_stomp_heartbeat_SUITE.erl
@@ -0,0 +1,53 @@
+%%--------------------------------------------------------------------
+%% Copyright (c) 2020 EMQ Technologies Co., Ltd. All Rights Reserved.
+%%
+%% Licensed under the Apache License, Version 2.0 (the "License");
+%% you may not use this file except in compliance with the License.
+%% You may obtain a copy of the License at
+%%
+%% http://www.apache.org/licenses/LICENSE-2.0
+%%
+%% Unless required by applicable law or agreed to in writing, software
+%% distributed under the License is distributed on an "AS IS" BASIS,
+%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+%% See the License for the specific language governing permissions and
+%% limitations under the License.
+%%--------------------------------------------------------------------
+
+-module(emqx_stomp_heartbeat_SUITE).
+
+-compile(export_all).
+-compile(nowarn_export_all).
+
+all() -> emqx_ct:all(?MODULE).
+
+%%--------------------------------------------------------------------
+%% Test Cases
+%%--------------------------------------------------------------------
+
+t_init(_) ->
+ #{} = emqx_stomp_heartbeat:init({0, 0}),
+ #{incoming := _} = emqx_stomp_heartbeat:init({1, 0}),
+ #{outgoing := _} = emqx_stomp_heartbeat:init({0, 1}).
+
+t_check_1(_) ->
+ HrtBt = emqx_stomp_heartbeat:init({1, 1}),
+ {ok, HrtBt1} = emqx_stomp_heartbeat:check(incoming, 0, HrtBt),
+ {error, timeout} = emqx_stomp_heartbeat:check(incoming, 0, HrtBt1),
+
+ {ok, HrtBt2} = emqx_stomp_heartbeat:check(outgoing, 0, HrtBt1),
+ {error, timeout} = emqx_stomp_heartbeat:check(outgoing, 0, HrtBt2),
+ ok.
+
+t_check_2(_) ->
+ HrtBt = emqx_stomp_heartbeat:init({1, 0}),
+ #{incoming := _} = lists:foldl(fun(I, Acc) ->
+ {ok, NAcc} = emqx_stomp_heartbeat:check(incoming, I, Acc),
+ NAcc
+ end, HrtBt, lists:seq(1,1000)),
+ ok.
+
+t_info(_) ->
+ HrtBt = emqx_stomp_heartbeat:init({100, 100}),
+ #{incoming := _,
+ outgoing := _} = emqx_stomp_heartbeat:info(HrtBt).