From a2860aa69c642a5a5be990851ac4faaab0d0293b Mon Sep 17 00:00:00 2001
From: Indrek Juhkam <indrek@urgas.eu>
Date: Sat, 29 Jun 2024 18:37:16 +0300
Subject: [PATCH] Fix folding over transactions

The brod:fold/8 did not work correctly when the last message in the
message set a transaction commit.

The output looked like this:
```
%%% brod_consumer_SUITE ==> t_fold_transactions: FAILED
%%% brod_consumer_SUITE ==> {function_clause,
    [{lists,last,[[]],[{file,"lists.erl"},{line,228}]},
     {brod_utils,handle_fetch_rsp,7,
         [{file,"/home/indrek/gems/brod/src/brod_utils.erl"},{line,661}]},
     {brod_consumer_SUITE,t_fold_transactions,1,
         [{file,"/home/indrek/gems/brod/test/brod_consumer_SUITE.erl"},
          {line,443}]},
     {test_server,ts_tc,3,[{file,"test_server.erl"},{line,1783}]},
     {test_server,run_test_case_eval1,6,
         [{file,"test_server.erl"},{line,1292}]},
     {test_server,run_test_case_eval,9,
         [{file,"test_server.erl"},{line,1224}]}]}
```

The issue was that `#kafka_message{offset = LastOffset} = lists:last(Msgs),`
was used, when the Msgs list was empty.

Now instead of trying to infer next offset from the kafka_message, we
pass the NextFetchOffset from the brod_utils:fetch_one_batch/4 function.

Fixes #588
---
 src/brod_utils.erl           | 74 +++++++++++++++++++++++++++---------
 test/brod_consumer_SUITE.erl | 21 ++++++++++
 2 files changed, 77 insertions(+), 18 deletions(-)

diff --git a/src/brod_utils.erl b/src/brod_utils.erl
index f7732a1a..e6525f01 100644
--- a/src/brod_utils.erl
+++ b/src/brod_utils.erl
@@ -31,6 +31,7 @@
         , epoch_ms/0
         , fetch/4
         , fetch/5
+        , fetch_one_batch/4
         , fold/8
         , fetch_committed_offsets/3
         , fetch_committed_offsets/4
@@ -66,6 +67,8 @@
 -type req_fun() :: fun((offset(), kpro:count()) -> kpro:req()).
 -type fetch_fun() :: fun((offset()) -> {ok, {offset(), [brod:message()]}} |
                                        {error, any()}).
+-type fetch_fun2() :: fun((offset()) -> {ok, {offset(), offset(), [brod:message()]}} |
+                                       {error, any()}).
 -type connection() :: kpro:connection().
 -type conn_config() :: brod:conn_config().
 -type topic() :: brod:topic().
@@ -331,7 +334,7 @@ fold(Client, Topic, Partition, Offset, Opts,
       ?BROD_FOLD_RET(Acc, Offset, {error, Reason})
   end;
 fold(Conn, Topic, Partition, Offset, Opts, Acc, Fun, Limits) ->
-  Fetch = make_fetch_fun(Conn, Topic, Partition, Opts),
+  Fetch = make_fetch_fun2(Conn, Topic, Partition, Opts),
   Infinity = 1 bsl 64,
   EndOffset = maps:get(reach_offset, Limits, Infinity),
   CountLimit = maps:get(message_count, Limits, Infinity),
@@ -346,13 +349,21 @@ fold(Conn, Topic, Partition, Offset, Opts, Acc, Fun, Limits) ->
 -spec make_fetch_fun(pid(), topic(), partition(), brod:fetch_opts()) ->
         fetch_fun().
 make_fetch_fun(Conn, Topic, Partition, FetchOpts) ->
+  make_fetch_fun(Conn, Topic, Partition, FetchOpts, fun fetch/4).
+
+-spec make_fetch_fun2(pid(), topic(), partition(), brod:fetch_opts()) ->
+        fetch_fun2().
+make_fetch_fun2(Conn, Topic, Partition, FetchOpts) ->
+  make_fetch_fun(Conn, Topic, Partition, FetchOpts, fun fetch_one_batch/4).
+
+make_fetch_fun(Conn, Topic, Partition, FetchOpts, FetchFun) ->
   WaitTime = maps:get(max_wait_time, FetchOpts, 1000),
   MinBytes = maps:get(min_bytes, FetchOpts, 1),
   MaxBytes = maps:get(max_bytes, FetchOpts, 1 bsl 20),
   IsolationLevel = maps:get(isolation_level, FetchOpts, ?kpro_read_committed),
   ReqFun = make_req_fun(Conn, Topic, Partition, WaitTime,
                         MinBytes, IsolationLevel),
-  fun(Offset) -> ?MODULE:fetch(Conn, ReqFun, Offset, MaxBytes) end.
+  fun(Offset) -> FetchFun(Conn, ReqFun, Offset, MaxBytes) end.
 
 -spec make_part_fun(brod:partitioner()) -> brod:partition_fun().
 make_part_fun(random) ->
@@ -445,12 +456,37 @@ do_fetch_committed_offsets(Conn, GroupId, Topics) when is_pid(Conn) ->
 -spec fetch(connection(), req_fun(), offset(), kpro:count()) ->
                {ok, {offset(), [brod:message()]}} | {error, any()}.
 fetch(Conn, ReqFun, Offset, MaxBytes) ->
+  case do_fetch(Conn, ReqFun, Offset, MaxBytes) of
+    {ok, {StableOffset, _NextOffset, Msgs}} ->
+      {ok, {StableOffset, Msgs}}; %% for backward compatibility
+    Other ->
+      Other
+  end.
+
+%% @doc Fetch a message-set. If the given MaxBytes is not enough to fetch a
+%% single message, expand it to fetch exactly one message
+%% The fetch/4 may return an empty batch even if there can be more messages in
+%% the topic. This function returns a non-empty batch unless the stable offset
+%% is reached.
+-spec fetch_one_batch(connection(), req_fun(), offset(), kpro:count()) ->
+        {ok, {offset(), offset(), [brod:message()]}} | {error, any()}.
+fetch_one_batch(Conn, ReqFun, Offset, MaxBytes) ->
+  case do_fetch(Conn, ReqFun, Offset, MaxBytes) of
+    {ok, {StableOffset, NextOffset, []}} when NextOffset < StableOffset ->
+      fetch_one_batch(Conn, ReqFun, NextOffset, MaxBytes);
+    Other ->
+      Other
+  end.
+
+-spec do_fetch(connection(), req_fun(), offset(), kpro:count()) ->
+        {ok, {offset(), offset(), [brod:message()]}} | {error, any()}.
+do_fetch(Conn, ReqFun, Offset, MaxBytes) ->
   Request = ReqFun(Offset, MaxBytes),
   case request_sync(Conn, Request, infinity) of
     {ok, #{error_code := ErrorCode}} when ?IS_ERROR(ErrorCode) ->
       {error, ErrorCode};
     {ok, #{batches := ?incomplete_batch(Size)}} ->
-      fetch(Conn, ReqFun, Offset, Size);
+      do_fetch(Conn, ReqFun, Offset, Size);
     {ok, #{header := Header, batches := Batches}} ->
       StableOffset = get_stable_offset(Header),
       {NewBeginOffset0, Msgs} = flatten_batches(Offset, Header, Batches),
@@ -472,9 +508,9 @@ fetch(Conn, ReqFun, Offset, MaxBytes) ->
                 %% we can only bump begin_offset with +1 and try again.
                 NewBeginOffset0 + 1
             end,
-          fetch(Conn, ReqFun, NewBeginOffset, MaxBytes);
+          do_fetch(Conn, ReqFun, NewBeginOffset, MaxBytes);
         false ->
-          {ok, {StableOffset, Msgs}}
+          {ok, {StableOffset, NewBeginOffset0, Msgs}}
       end;
     {error, Reason} ->
       {error, Reason}
@@ -636,32 +672,34 @@ do_fold(Spawn, {Pid, Mref}, Offset, Acc, Fun, End, Count) ->
 
 handle_fetch_rsp(_Spawn, {error, Reason}, Offset, Acc, _Fun, _, _) ->
   ?BROD_FOLD_RET(Acc, Offset, {fetch_failure, Reason});
-handle_fetch_rsp(_Spawn, {ok, {StableOffset, []}}, Offset, Acc, _Fun,
+handle_fetch_rsp(_Spawn, {ok, {StableOffset, _NextFetchOffset, []}}, Offset, Acc, _Fun,
                 _End, _Count) when Offset >= StableOffset ->
   ?BROD_FOLD_RET(Acc, Offset, reached_end_of_partition);
-handle_fetch_rsp(Spawn, {ok, {_StableOffset, Msgs}}, Offset, Acc, Fun,
+handle_fetch_rsp(Spawn, {ok, {_StableOffset, NextFetchOffset, Msgs}}, Offset, Acc, Fun,
                  End, Count) ->
-  #kafka_message{offset = LastOffset} = lists:last(Msgs),
-  %% start fetching the next batch if not stopping at current
-  Fetcher = case LastOffset < End andalso length(Msgs) < Count of
-              true -> Spawn(LastOffset + 1);
+  Fetcher = case NextFetchOffset =< End andalso length(Msgs) < Count of
+              true -> Spawn(NextFetchOffset);
               false -> undefined
             end,
-  do_acc(Spawn, Fetcher, Offset, Acc, Fun, Msgs, End, Count).
+  do_acc(Spawn, Fetcher, NextFetchOffset, Offset, Acc, Fun, Msgs, End, Count).
 
-do_acc(_Spawn, Fetcher, Offset, Acc, _Fun, _, _End, 0) ->
+do_acc(_Spawn, Fetcher, _NextFetchOffset, Offset, Acc, _Fun, _, _End, 0) ->
   undefined = Fetcher, %% assert
   ?BROD_FOLD_RET(Acc, Offset, reached_message_count_limit);
-do_acc(_Spawn, Fetcher, Offset, Acc, _Fun, _, End, _Count) when Offset > End ->
+do_acc(_Spawn, Fetcher, _NextFetchOffset, Offset, Acc, _Fun, _, End, _Count) when Offset > End  ->
   undefined = Fetcher, %% assert
   ?BROD_FOLD_RET(Acc, Offset, reached_target_offset);
-do_acc(Spawn, Fetcher, Offset, Acc, Fun, [], End, Count) ->
-  do_fold(Spawn, Fetcher, Offset, Acc, Fun, End, Count);
-do_acc(Spawn, Fetcher, Offset, Acc, Fun, [Msg | Rest], End, Count) ->
+do_acc(_Spawn, Fetcher, NextFetchOffset, _Offset, Acc, _Fun, [], End, _Count)
+        when NextFetchOffset > End ->
+  undefined = Fetcher, %% assert
+  ?BROD_FOLD_RET(Acc, NextFetchOffset, reached_target_offset);
+do_acc(Spawn, Fetcher, NextFetchOffset, _Offset, Acc, Fun, [], End, Count) ->
+  do_fold(Spawn, Fetcher, NextFetchOffset, Acc, Fun, End, Count);
+do_acc(Spawn, Fetcher, NextFetchOffset, Offset, Acc, Fun, [Msg | Rest], End, Count) ->
   try Fun(Msg, Acc) of
     {ok, NewAcc} ->
       NextOffset = Msg#kafka_message.offset + 1,
-      do_acc(Spawn, Fetcher, NextOffset, NewAcc, Fun, Rest, End, Count - 1);
+      do_acc(Spawn, Fetcher, NextFetchOffset, NextOffset, NewAcc, Fun, Rest, End, Count - 1);
     {error, Reason} ->
       ok = kill_fetcher(Fetcher),
       ?BROD_FOLD_RET(Acc, Offset, Reason)
diff --git a/test/brod_consumer_SUITE.erl b/test/brod_consumer_SUITE.erl
index 93774d1a..77e602c0 100644
--- a/test/brod_consumer_SUITE.erl
+++ b/test/brod_consumer_SUITE.erl
@@ -32,6 +32,7 @@
         , t_fetch_aborted_from_the_middle/1
         , t_direct_fetch/1
         , t_fold/1
+        , t_fold_transactions/1
         , t_direct_fetch_with_small_max_bytes/1
         , t_direct_fetch_expand_max_bytes/1
         , t_resolve_offset/1
@@ -422,6 +423,26 @@ t_fold(Config) when is_list(Config) ->
               0, ErrorFoldF, #{})),
   ok.
 
+t_fold_transactions(kafka_version_match) ->
+  has_txn();
+t_fold_transactions(Config) when is_list(Config) ->
+  Client = ?config(client),
+  Topic = ?TOPIC,
+  Partition = 0,
+  Batch = [#{value => <<"one">>}, #{value => <<"two">>}],
+  {ok, Tx} = brod:transaction(Client, <<"some_transaction">>, []),
+  {ok, Offset} = brod:txn_produce(Tx, ?TOPIC, Partition, Batch),
+  ok = brod:commit(Tx),
+  FoldF =
+    fun F(#kafka_message{value = V}, Acc) -> {ok, F(V, Acc)};
+        F(V, Acc) -> [V | Acc]
+    end,
+  FetchOpts = #{max_bytes => 1},
+  ?assertMatch({Result, O, reached_end_of_partition}
+                when O =:= Offset + length(Batch) + 1 andalso length(Result) =:= 2,
+    brod:fold(Client, Topic, Partition, Offset, FetchOpts, [], FoldF, #{})),
+  ok.
+
 %% This test case does not work with Kafka 0.9, not sure aobut 0.10 and 0.11
 %% since all 0.x versions are old enough, we only try to verify this against
 %% 1.x or newer