diff --git a/executor/cte.go b/executor/cte.go index 41709dc0bb3b0..8acde6123cc14 100644 --- a/executor/cte.go +++ b/executor/cte.go @@ -86,16 +86,16 @@ func (e *CTEExec) Open(ctx context.Context) (err error) { defer e.producer.resTbl.Unlock() if e.producer.checkAndUpdateCorColHashCode() { - e.producer.reset() - if err = e.producer.reopenTbls(); err != nil { + err = e.producer.reset() + if err != nil { return err } } if e.producer.openErr != nil { return e.producer.openErr } - if !e.producer.opened { - if err = e.producer.openProducer(ctx, e); err != nil { + if !e.producer.hasCTEResult() && !e.producer.executorOpened { + if err = e.producer.openProducerExecutor(ctx, e); err != nil { return err } } @@ -106,8 +106,14 @@ func (e *CTEExec) Open(ctx context.Context) (err error) { func (e *CTEExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { e.producer.resTbl.Lock() defer e.producer.resTbl.Unlock() - if !e.producer.resTbl.Done() { - if err = e.producer.produce(ctx); err != nil { + if !e.producer.hasCTEResult() { + // Another CTEExec may have called Close() before the CTE result was generated; reopen the executors if so. + if !e.producer.executorOpened { + if err = e.producer.openProducerExecutor(ctx, e); err != nil { + return err + } + } + if err = e.producer.genCTEResult(ctx); err != nil { return err } } @@ -129,7 +135,7 @@ func (e *CTEExec) Close() (firstErr error) { func() { e.producer.resTbl.Lock() defer e.producer.resTbl.Unlock() - if !e.producer.closed { + if e.producer.executorOpened { failpoint.Inject("mock_cte_exec_panic_avoid_deadlock", func(v failpoint.Value) { ok := v.(bool) if ok { @@ -137,12 +143,17 @@ func (e *CTEExec) Close() (firstErr error) { panic(memory.PanicMemoryExceedWarnMsg) } }) - // closeProducer() only close seedExec and recursiveExec, will not touch resTbl. - // It means you can still read resTbl after call closeProducer(). - // You can even call all three functions(openProducer/produce/closeProducer) in CTEExec.Next(). 
+ // closeProducerExecutor() only closes seedExec and recursiveExec; it does not touch resTbl. + // This means resTbl can still be read after calling closeProducerExecutor(). + // All three functions (openProducerExecutor/genCTEResult/closeProducerExecutor) could even be called from CTEExec.Next(). // Separating these three function calls is only to follow the abstraction of the volcano model. - err := e.producer.closeProducer() + err := e.producer.closeProducerExecutor() firstErr = setFirstErr(firstErr, err, "close cte producer error") + if !e.producer.hasCTEResult() { + // The CTE result has not been generated, so reset the producer. + err = e.producer.reset() + firstErr = setFirstErr(firstErr, err, "close cte producer error") + } } }() err := e.baseExecutor.Close() @@ -157,10 +168,10 @@ func (e *CTEExec) reset() { } type cteProducer struct { - // opened should be false when not open or open fail(a.k.a. openErr != nil) - opened bool - produced bool - closed bool + // executorOpened indicates whether the executors (seedExec/recursiveExec) are open. + // It is true once the executors have been opened, and false when they have not been + // opened yet or have already been closed. + executorOpened bool // cteProducer is shared by multiple operators, so if the first operator tries to open // and got error, the second should return open error directly instead of open again. 
@@ -199,14 +210,10 @@ type cteProducer struct { corColHashCodes [][]byte } -func (p *cteProducer) openProducer(ctx context.Context, cteExec *CTEExec) (err error) { +func (p *cteProducer) openProducerExecutor(ctx context.Context, cteExec *CTEExec) (err error) { defer func() { p.openErr = err - if err == nil { - p.opened = true - } else { - p.opened = false - } + p.executorOpened = true }() if p.seedExec == nil { return errors.New("seedExec for CTEExec is nil") @@ -249,7 +256,7 @@ func (p *cteProducer) openProducer(ctx context.Context, cteExec *CTEExec) (err e return nil } -func (p *cteProducer) closeProducer() (firstErr error) { +func (p *cteProducer) closeProducerExecutor() (firstErr error) { err := p.seedExec.Close() firstErr = setFirstErr(firstErr, err, "close seedExec err") if p.recursiveExec != nil { @@ -267,7 +274,7 @@ func (p *cteProducer) closeProducer() (firstErr error) { // because ExplainExec still needs tracker to get mem usage info. p.memTracker = nil p.diskTracker = nil - p.closed = true + p.executorOpened = false return } @@ -334,7 +341,13 @@ func (p *cteProducer) nextChunkLimit(cteExec *CTEExec, req *chunk.Chunk) error { return nil } -func (p *cteProducer) produce(ctx context.Context) (err error) { +func (p *cteProducer) hasCTEResult() bool { + return p.resTbl.Done() +} + +// genCTEResult generates the CTE result and stores it in resTbl. +// It is synchronous: the call blocks until the result has been generated. 
+func (p *cteProducer) genCTEResult(ctx context.Context) (err error) { if p.resTbl.Error() != nil { return p.resTbl.Error() } @@ -527,14 +540,18 @@ func (p *cteProducer) setupTblsForNewIteration() (err error) { return nil } -func (p *cteProducer) reset() { +func (p *cteProducer) reset() error { p.curIter = 0 p.hashTbl = nil - - p.opened = false + p.executorOpened = false p.openErr = nil - p.produced = false - p.closed = false + + // Normally the tracker has to be set up after calling Reopen(), + // but reopening resTbl means genCTEResult() must be called again, and that call sets up the tracker. + if err := p.resTbl.Reopen(); err != nil { + return err + } + return p.iterInTbl.Reopen() } func (p *cteProducer) resetTracker() { @@ -548,18 +565,6 @@ func (p *cteProducer) resetTracker() { } } -func (p *cteProducer) reopenTbls() (err error) { - if p.isDistinct { - p.hashTbl = newConcurrentMapHashTable() - } - // Normally we need to setup tracker after calling Reopen(), - // But reopen resTbl means we need to call produce() again, it will setup tracker. - if err := p.resTbl.Reopen(); err != nil { - return err - } - return p.iterInTbl.Reopen() -} - // Check if tbl meets the requirement of limit. 
func (p *cteProducer) limitDone(tbl cteutil.Storage) bool { return p.hasLimit && uint64(tbl.NumRows()) >= p.limitEnd diff --git a/executor/issuetest/executor_issue_test.go b/executor/issuetest/executor_issue_test.go index 06ad38354c408..a9260c6b0183e 100644 --- a/executor/issuetest/executor_issue_test.go +++ b/executor/issuetest/executor_issue_test.go @@ -1463,3 +1463,22 @@ func TestIssue49902(t *testing.T) { tk.MustQuery("SELECT count(`t`.`c`) FROM (`s`) JOIN `t` GROUP BY `t`.`c`;").Check(testkit.Rows("170")) tk.MustExec("set @@tidb_max_chunk_size = default;") } + +func TestIssue55881(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists aaa;") + tk.MustExec("drop table if exists bbb;") + tk.MustExec("create table aaa(id int, value int);") + tk.MustExec("create table bbb(id int, value int);") + tk.MustExec("insert into aaa values(1,2),(2,3)") + tk.MustExec("insert into bbb values(1,2),(2,3),(3,4)") + // Set tidb_executor_concurrency to 1 so that the issue reproduces with high probability. + tk.MustExec("set tidb_executor_concurrency=1;") + // The issue is nondeterministic, so run the query 100 times to increase the chance of hitting it. + for i := 0; i < 100; i++ { + tk.MustQuery("with cte as (select * from aaa) select id, (select id from (select * from aaa where aaa.id != bbb.id union all select * from cte union all select * from cte) d limit 1)," + + "(select max(value) from (select * from cte union all select * from cte union all select * from aaa where aaa.id > bbb.id)) from bbb;") + } +}