forked from tinygo-org/tinygo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
goroutine-lowering.go
588 lines (531 loc) · 23.7 KB
/
goroutine-lowering.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
package compiler
// This file lowers goroutine pseudo-functions into coroutines scheduled by a
// scheduler at runtime. It uses coroutine support in LLVM for this
// transformation: https://llvm.org/docs/Coroutines.html
//
// For example, take the following code:
//
// func main() {
// go foo()
// time.Sleep(2 * time.Second)
// println("some other operation")
// bar()
// println("done")
// }
//
// func foo() {
// for {
// println("foo!")
// time.Sleep(time.Second)
// }
// }
//
// func bar() {
// time.Sleep(time.Second)
// println("blocking operation completed")
// }
//
// It is transformed by the IR generator in compiler.go into the following
// pseudo-Go code:
//
// func main() {
// fn := runtime.makeGoroutine(foo)
// fn()
// time.Sleep(2 * time.Second)
// println("some other operation")
// bar() // imagine an 'await' keyword in front of this call
// println("done")
// }
//
// func foo() {
// for {
// println("foo!")
// time.Sleep(time.Second)
// }
// }
//
// func bar() {
// time.Sleep(time.Second)
// println("blocking operation completed")
// }
//
// The pass in this file transforms this code even further, to the following
// async/await style pseudocode:
//
// func main(parent) {
// hdl := llvm.makeCoroutine()
// foo(nil) // do not pass the parent coroutine: this is an independent goroutine
// runtime.sleepTask(hdl, 2 * time.Second) // ask the scheduler to re-activate this coroutine at the right time
// llvm.suspend(hdl) // suspend point
// println("some other operation")
// bar(hdl) // await, pass a continuation (hdl) to bar
// llvm.suspend(hdl) // suspend point, wait for the callee to re-activate
// println("done")
// runtime.activateTask(parent) // re-activate the parent (nop, there is no parent)
// }
//
// func foo(parent) {
// hdl := llvm.makeCoroutine()
// for {
// println("foo!")
// runtime.sleepTask(hdl, time.Second) // ask the scheduler to re-activate this coroutine at the right time
// llvm.suspend(hdl) // suspend point
// }
// }
//
// func bar(parent) {
// hdl := llvm.makeCoroutine()
// runtime.sleepTask(hdl, time.Second) // ask the scheduler to re-activate this coroutine at the right time
// llvm.suspend(hdl) // suspend point
// println("blocking operation completed")
// runtime.activateTask(parent) // re-activate the parent coroutine before returning
// }
//
// The real LLVM code is more complicated, but this is the general idea.
//
// The LLVM coroutine passes will then process this code further, transforming
// these three functions into coroutines. Most of the actual work is done by the
// scheduler, which runs in the background scheduling all coroutines.
import (
"errors"
"strings"
"tinygo.org/x/go-llvm"
)
// asyncFunc holds the per-function coroutine state that markAsyncFunctions
// creates for every function it turns into a coroutine. The zero value is
// used as a placeholder for functions that are only marked async (for
// example time.Sleep), which never get their fields filled in.
type asyncFunc struct {
	// taskHandle is the coroutine handle of this function: the result of the
	// llvm.coro.begin call inserted at the function entry.
	taskHandle llvm.Value

	// cleanupBlock ("task.cleanup") frees the coroutine frame and then
	// branches to suspendBlock.
	cleanupBlock llvm.BasicBlock

	// suspendBlock ("task.suspend") calls llvm.coro.end and returns from the
	// function. It is the default destination of every suspend switch.
	suspendBlock llvm.BasicBlock

	// unreachableBlock ("task.unreachable") only contains an unreachable
	// instruction. It is the "resumed" destination of the final suspend that
	// replaces a return instruction.
	unreachableBlock llvm.BasicBlock
}
// LowerGoroutines is a pass called during optimization. It rewrites the IR so
// that all blocking functions become coroutines and blocking calls become
// await calls, and it replaces the runtime.callMain placeholder with a direct
// call to the program's main.main (followed by runtime.scheduler when a
// scheduler turns out to be necessary).
//
// It returns any error reported by markAsyncFunctions. It panics when
// runtime.callMain is not called exactly once.
func (c *Compiler) LowerGoroutines() error {
	needsScheduler, err := c.markAsyncFunctions()
	if err != nil {
		return err
	}

	callMainUses := getUses(c.mod.NamedFunction("runtime.callMain"))
	if len(callMainUses) != 1 || callMainUses[0].IsACallInst().IsNil() {
		panic("expected exactly 1 call of runtime.callMain, check the entry point")
	}
	placeholder := callMainUses[0]

	// Emit the real call to main.main() right where the placeholder call
	// lives. The last argument is the parent coroutine handle, which is nil
	// because main has no parent.
	c.builder.SetInsertPointBefore(placeholder)
	realMain := c.mod.NamedFunction(c.ir.MainPkg().Pkg.Path() + ".main")
	mainArgs := []llvm.Value{
		llvm.Undef(c.i8ptrType),
		llvm.ConstPointerNull(c.i8ptrType),
	}
	c.builder.CreateCall(realMain, mainArgs, "")

	if needsScheduler {
		// Run the scheduler once main.main() has returned.
		c.createRuntimeCall("scheduler", nil, "")
	} else if goScheduler := c.mod.NamedFunction("go_scheduler"); !goScheduler.IsNil() {
		// This is the WebAssembly backend.
		// There is no need to export the go_scheduler function, but it is
		// still exported. Make sure it is optimized away.
		goScheduler.SetLinkage(llvm.InternalLinkage)
	}
	placeholder.EraseFromParentAsInstruction()

	// main.main was set to external linkage during IR construction. Set it
	// (and the runtime helpers used by this pass) to internal linkage to
	// enable interprocedural optimizations.
	realMain.SetLinkage(llvm.InternalLinkage)
	for _, name := range []string{
		"alloc",
		"free",
		"chanSend",
		"chanRecv",
		"sleepTask",
		"activateTask",
		"scheduler",
	} {
		c.mod.NamedFunction("runtime." + name).SetLinkage(llvm.InternalLinkage)
	}
	return nil
}
// markAsyncFunctions does the bulk of the work of lowering goroutines. It
// determines whether a scheduler is needed, and if it is, it transforms
// blocking operations into goroutines and blocking calls into await calls.
//
// It does the following operations:
//   * Find all blocking functions.
//   * Determine whether a scheduler is necessary. If not, it skips the
//     following operations.
//   * Transform call instructions into await calls.
//   * Transform return instructions into final suspends.
//   * Set up the coroutine frames for async functions.
//   * Transform blocking calls into their async equivalents.
//
// In every non-error path it finishes by calling lowerMakeGoroutineCalls and
// returns that call's error.
//
// needsScheduler is true only when the coroutine transformation was actually
// performed. err is non-nil when an async function is used as a function
// pointer, when it is used in a constant bitcast that is not passed to
// runtime.makeGoroutine, or when lowerMakeGoroutineCalls fails; in the first
// two cases needsScheduler is false.
func (c *Compiler) markAsyncFunctions() (needsScheduler bool, err error) {
// Seed the worklist with the primitive blocking operations that are present
// in this module.
var worklist []llvm.Value
sleep := c.mod.NamedFunction("time.Sleep")
if !sleep.IsNil() {
worklist = append(worklist, sleep)
}
chanSendStub := c.mod.NamedFunction("runtime.chanSendStub")
if !chanSendStub.IsNil() {
worklist = append(worklist, chanSendStub)
}
chanRecvStub := c.mod.NamedFunction("runtime.chanRecvStub")
if !chanRecvStub.IsNil() {
worklist = append(worklist, chanRecvStub)
}
if len(worklist) == 0 {
// There are no blocking operations, so no need to transform anything.
return false, c.lowerMakeGoroutineCalls()
}
// Find all async functions.
// Keep reducing this worklist by marking a function as recursively async
// from the worklist and pushing all its parents that are non-async.
// This is somewhat similar to a worklist in a mark-sweep garbage collector:
// the work items are then grey objects.
asyncFuncs := make(map[llvm.Value]*asyncFunc)
// asyncList records the async functions in discovery order; it is the list
// that is iterated below when transforming functions (asyncFuncs is a map).
asyncList := make([]llvm.Value, 0, 4)
for len(worklist) != 0 {
// Pick the topmost.
f := worklist[len(worklist)-1]
worklist = worklist[:len(worklist)-1]
if _, ok := asyncFuncs[f]; ok {
continue // already processed
}
// Add to set of async functions.
asyncFuncs[f] = &asyncFunc{}
asyncList = append(asyncList, f)
// Add all callers (the functions containing a call to f) to the worklist.
for _, use := range getUses(f) {
if use.IsConstant() && use.Opcode() == llvm.BitCast {
bitcastUses := getUses(use)
for _, call := range bitcastUses {
if call.IsACallInst().IsNil() || call.CalledValue().Name() != "runtime.makeGoroutine" {
return false, errors.New("async function " + f.Name() + " incorrectly used in bitcast, expected runtime.makeGoroutine")
}
}
// This is a go statement. Do not mark the parent as async, as
// starting a goroutine is not a blocking operation.
continue
}
if use.IsACallInst().IsNil() {
// Not a call instruction. Maybe a store to a global? In any
// case, this requires support for async calls across function
// pointers which is not yet supported.
return false, errors.New("async function " + f.Name() + " used as function pointer")
}
parent := use.InstructionParent().Parent()
// Check every operand except the last one (the last operand of a call
// instruction is the called function): f appearing there means it is
// passed as an argument rather than being called.
for i := 0; i < use.OperandsCount()-1; i++ {
if use.Operand(i) == f {
return false, errors.New("async function " + f.Name() + " used as function pointer in " + parent.Name())
}
}
worklist = append(worklist, parent)
}
}
// Check whether a scheduler is needed.
makeGoroutine := c.mod.NamedFunction("runtime.makeGoroutine")
if c.GOOS == "js" && strings.HasPrefix(c.Triple, "wasm") {
// JavaScript always needs a scheduler, as in general no blocking
// operations are possible. Blocking operations block the browser UI,
// which is very bad.
needsScheduler = true
} else {
// Only use a scheduler when an async goroutine is started. When the
// goroutine is not async (does not do any blocking operation), no
// scheduler is necessary as it can be called directly.
for _, use := range getUses(makeGoroutine) {
// Input param must be const bitcast of function.
bitcast := use.Operand(0)
if !bitcast.IsConstant() || bitcast.Opcode() != llvm.BitCast {
panic("expected const bitcast operand of runtime.makeGoroutine")
}
goroutine := bitcast.Operand(0)
if _, ok := asyncFuncs[goroutine]; ok {
needsScheduler = true
break
}
}
}
if !needsScheduler {
// No scheduler is needed. Do not transform all functions here.
// However, make sure that all go calls (which are all non-async) are
// transformed into regular calls.
return false, c.lowerMakeGoroutineCalls()
}
// Create a few LLVM intrinsics for coroutine support.
coroIdType := llvm.FunctionType(c.ctx.TokenType(), []llvm.Type{c.ctx.Int32Type(), c.i8ptrType, c.i8ptrType, c.i8ptrType}, false)
coroIdFunc := llvm.AddFunction(c.mod, "llvm.coro.id", coroIdType)
coroSizeType := llvm.FunctionType(c.ctx.Int32Type(), nil, false)
coroSizeFunc := llvm.AddFunction(c.mod, "llvm.coro.size.i32", coroSizeType)
coroBeginType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.ctx.TokenType(), c.i8ptrType}, false)
coroBeginFunc := llvm.AddFunction(c.mod, "llvm.coro.begin", coroBeginType)
coroPromiseType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.i8ptrType, c.ctx.Int32Type(), c.ctx.Int1Type()}, false)
coroPromiseFunc := llvm.AddFunction(c.mod, "llvm.coro.promise", coroPromiseType)
coroSuspendType := llvm.FunctionType(c.ctx.Int8Type(), []llvm.Type{c.ctx.TokenType(), c.ctx.Int1Type()}, false)
coroSuspendFunc := llvm.AddFunction(c.mod, "llvm.coro.suspend", coroSuspendType)
coroEndType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptrType, c.ctx.Int1Type()}, false)
coroEndFunc := llvm.AddFunction(c.mod, "llvm.coro.end", coroEndType)
coroFreeType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.ctx.TokenType(), c.i8ptrType}, false)
coroFreeFunc := llvm.AddFunction(c.mod, "llvm.coro.free", coroFreeType)
// Transform all async functions into coroutines.
for _, f := range asyncList {
// The primitive blocking operations themselves are not turned into
// coroutines; calls to them are rewritten further below instead.
if f == sleep || f == chanSendStub || f == chanRecvStub {
continue
}
frame := asyncFuncs[f]
frame.cleanupBlock = c.ctx.AddBasicBlock(f, "task.cleanup")
frame.suspendBlock = c.ctx.AddBasicBlock(f, "task.suspend")
frame.unreachableBlock = c.ctx.AddBasicBlock(f, "task.unreachable")
// Scan for async calls and return instructions that need to have
// suspend points inserted.
var asyncCalls []llvm.Value
var returns []llvm.Value
for bb := f.EntryBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) {
for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) {
if !inst.IsACallInst().IsNil() {
callee := inst.CalledValue()
// Skip calls to non-async functions and calls to the primitive
// blocking operations (those are handled separately below).
if _, ok := asyncFuncs[callee]; !ok || callee == sleep || callee == chanSendStub || callee == chanRecvStub {
continue
}
asyncCalls = append(asyncCalls, inst)
} else if !inst.IsAReturnInst().IsNil() {
returns = append(returns, inst)
}
}
}
// Coroutine setup.
c.builder.SetInsertPointBefore(f.EntryBasicBlock().FirstInstruction())
// The runtime.taskState alloca is passed to llvm.coro.id as the coroutine
// promise; it is accessed later through llvm.coro.promise.
taskState := c.builder.CreateAlloca(c.mod.GetTypeByName("runtime.taskState"), "task.state")
stateI8 := c.builder.CreateBitCast(taskState, c.i8ptrType, "task.state.i8")
id := c.builder.CreateCall(coroIdFunc, []llvm.Value{
llvm.ConstInt(c.ctx.Int32Type(), 0, false),
stateI8,
llvm.ConstNull(c.i8ptrType),
llvm.ConstNull(c.i8ptrType),
}, "task.token")
// llvm.coro.size.i32 yields an i32; convert it to uintptr width before
// passing it to runtime.alloc.
size := c.builder.CreateCall(coroSizeFunc, nil, "task.size")
if c.targetData.TypeAllocSize(size.Type()) > c.targetData.TypeAllocSize(c.uintptrType) {
size = c.builder.CreateTrunc(size, c.uintptrType, "task.size.uintptr")
} else if c.targetData.TypeAllocSize(size.Type()) < c.targetData.TypeAllocSize(c.uintptrType) {
size = c.builder.CreateZExt(size, c.uintptrType, "task.size.uintptr")
}
data := c.createRuntimeCall("alloc", []llvm.Value{size}, "task.data")
frame.taskHandle = c.builder.CreateCall(coroBeginFunc, []llvm.Value{id, data}, "task.handle")
// Modify async calls so this function suspends right after the child
// returns, because the child is probably not finished yet. Wait until
// the child reactivates the parent.
for _, inst := range asyncCalls {
// Pass this coroutine's handle as the last argument (the callee's parent
// handle); the final operand of the call is the callee itself.
inst.SetOperand(inst.OperandsCount()-2, frame.taskHandle)
// Split this basic block.
await := c.splitBasicBlock(inst, llvm.NextBasicBlock(c.builder.GetInsertBlock()), "task.await")
// Continue building at the end of the block that contains the call.
// (No task state is written here.)
c.builder.SetInsertPointAtEnd(inst.InstructionParent())
// Suspend.
continuePoint := c.builder.CreateCall(coroSuspendFunc, []llvm.Value{
llvm.ConstNull(c.ctx.TokenType()),
llvm.ConstInt(c.ctx.Int1Type(), 0, false),
}, "")
// Dispatch on the suspend result: 0 continues in the await block, 1 runs
// the cleanup block, anything else goes to the suspend block.
sw := c.builder.CreateSwitch(continuePoint, frame.suspendBlock, 2)
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), await)
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), frame.cleanupBlock)
}
// Replace return instructions with suspend points that should
// reactivate the parent coroutine.
for _, inst := range returns {
// Only returns without an operand (no return value) are supported.
if inst.OperandsCount() == 0 {
// These properties were added by the functionattrs pass.
// Remove them, because now we start using the parameter.
// https://llvm.org/docs/Passes.html#functionattrs-deduce-function-attributes
// The attributes are removed at index f.ParamsCount(), i.e. from the last
// parameter, which is the parent handle used below.
for _, kind := range []string{"nocapture", "readnone"} {
kindID := llvm.AttributeKindID(kind)
f.RemoveEnumAttributeAtIndex(f.ParamsCount(), kindID)
}
// Reactivate the parent coroutine. This adds it back to
// the run queue, so it is started again by the
// scheduler when possible (possibly right after the
// following suspend).
c.builder.SetInsertPointBefore(inst)
parentHandle := f.LastParam()
c.createRuntimeCall("activateTask", []llvm.Value{parentHandle}, "")
// Suspend this coroutine.
// It would look like this is unnecessary, but if this
// suspend point is left out, it leads to undefined
// behavior somehow (with the unreachable instruction).
// The second argument is 1 here (it is 0 at all other suspend points):
// this is the final suspend.
continuePoint := c.builder.CreateCall(coroSuspendFunc, []llvm.Value{
llvm.ConstNull(c.ctx.TokenType()),
llvm.ConstInt(c.ctx.Int1Type(), 1, false),
}, "ret")
sw := c.builder.CreateSwitch(continuePoint, frame.suspendBlock, 2)
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), frame.unreachableBlock)
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), frame.cleanupBlock)
inst.EraseFromParentAsInstruction()
} else {
panic("todo: return value from coroutine")
}
}
// Coroutine cleanup. Free resources associated with this coroutine.
c.builder.SetInsertPointAtEnd(frame.cleanupBlock)
mem := c.builder.CreateCall(coroFreeFunc, []llvm.Value{id, frame.taskHandle}, "task.data.free")
c.createRuntimeCall("free", []llvm.Value{mem}, "")
c.builder.CreateBr(frame.suspendBlock)
// Coroutine suspend. A call to llvm.coro.suspend() will branch here.
c.builder.SetInsertPointAtEnd(frame.suspendBlock)
c.builder.CreateCall(coroEndFunc, []llvm.Value{frame.taskHandle, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "unused")
// Return from the function; non-void functions return an undef value.
returnType := f.Type().ElementType().ReturnType()
if returnType.TypeKind() == llvm.VoidTypeKind {
c.builder.CreateRetVoid()
} else {
c.builder.CreateRet(llvm.Undef(returnType))
}
// Coroutine exit. All final suspends (return instructions) will branch
// here.
c.builder.SetInsertPointAtEnd(frame.unreachableBlock)
c.builder.CreateUnreachable()
}
// Transform calls to time.Sleep() into coroutine suspend points.
for _, sleepCall := range getUses(sleep) {
// sleepCall must be a call instruction.
frame := asyncFuncs[sleepCall.InstructionParent().Parent()]
duration := sleepCall.Operand(0)
// Call runtime.sleepTask with this coroutine's handle and the duration.
c.builder.SetInsertPointBefore(sleepCall)
c.createRuntimeCall("sleepTask", []llvm.Value{frame.taskHandle, duration}, "")
// Yield to scheduler.
continuePoint := c.builder.CreateCall(coroSuspendFunc, []llvm.Value{
llvm.ConstNull(c.ctx.TokenType()),
llvm.ConstInt(c.ctx.Int1Type(), 0, false),
}, "")
wakeup := c.splitBasicBlock(sleepCall, llvm.NextBasicBlock(c.builder.GetInsertBlock()), "task.wakeup")
c.builder.SetInsertPointBefore(sleepCall)
sw := c.builder.CreateSwitch(continuePoint, frame.suspendBlock, 2)
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), wakeup)
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), frame.cleanupBlock)
// The original time.Sleep call is no longer needed.
sleepCall.EraseFromParentAsInstruction()
}
// Transform calls to runtime.chanSendStub into channel send operations.
for _, sendOp := range getUses(chanSendStub) {
// sendOp must be a call instruction.
frame := asyncFuncs[sendOp.InstructionParent().Parent()]
// Send the value over the channel, or block.
// Operand 0 becomes the coroutine handle and the callee (the last
// operand) becomes runtime.chanSend.
sendOp.SetOperand(0, frame.taskHandle)
sendOp.SetOperand(sendOp.OperandsCount()-1, c.mod.NamedFunction("runtime.chanSend"))
// Use taskState.data to store the value to send:
// *(*valueType)(&coroutine.promise().data) = valueToSend
// runtime.chanSend(coroutine, ch)
// NOTE(review): this assumes operand 2 is a bitcast instruction whose
// operand is the alloca holding the value to send -- confirm against the
// code that emits runtime.chanSendStub calls.
bitcast := sendOp.Operand(2)
valueAlloca := bitcast.Operand(0)
c.builder.SetInsertPointBefore(valueAlloca)
promiseType := c.mod.GetTypeByName("runtime.taskState")
promiseRaw := c.builder.CreateCall(coroPromiseFunc, []llvm.Value{
frame.taskHandle,
llvm.ConstInt(c.ctx.Int32Type(), uint64(c.targetData.PrefTypeAlignment(promiseType)), false),
llvm.ConstInt(c.ctx.Int1Type(), 0, false),
}, "task.promise.raw")
promise := c.builder.CreateBitCast(promiseRaw, llvm.PointerType(promiseType, 0), "task.promise")
// Field index 2 of runtime.taskState is used as the data field.
dataPtr := c.builder.CreateGEP(promise, []llvm.Value{
llvm.ConstInt(c.ctx.Int32Type(), 0, false),
llvm.ConstInt(c.ctx.Int32Type(), 2, false),
}, "task.promise.data")
// Redirect all uses of the alloca to the promise data field, then remove
// the now-unused bitcast and alloca.
sendOp.SetOperand(2, llvm.Undef(c.i8ptrType))
valueAlloca.ReplaceAllUsesWith(c.builder.CreateBitCast(dataPtr, valueAlloca.Type(), ""))
bitcast.EraseFromParentAsInstruction()
valueAlloca.EraseFromParentAsInstruction()
// Yield to scheduler.
c.builder.SetInsertPointBefore(llvm.NextInstruction(sendOp))
continuePoint := c.builder.CreateCall(coroSuspendFunc, []llvm.Value{
llvm.ConstNull(c.ctx.TokenType()),
llvm.ConstInt(c.ctx.Int1Type(), 0, false),
}, "")
sw := c.builder.CreateSwitch(continuePoint, frame.suspendBlock, 2)
wakeup := c.splitBasicBlock(sw, llvm.NextBasicBlock(c.builder.GetInsertBlock()), "task.sent")
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), wakeup)
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), frame.cleanupBlock)
}
// Transform calls to runtime.chanRecvStub into channel receive operations.
for _, recvOp := range getUses(chanRecvStub) {
// recvOp must be a call instruction.
frame := asyncFuncs[recvOp.InstructionParent().Parent()]
// NOTE(review): this assumes operand 2 is a bitcast instruction of the
// alloca receiving the value and operand 3 is the comma-ok pointer --
// confirm against the code that emits runtime.chanRecvStub calls.
bitcast := recvOp.Operand(2)
commaOk := recvOp.Operand(3)
valueAlloca := bitcast.Operand(0)
// Receive the value over the channel, or block.
// Operand 0 becomes the coroutine handle and the callee (the last
// operand) becomes runtime.chanRecv.
recvOp.SetOperand(0, frame.taskHandle)
recvOp.SetOperand(recvOp.OperandsCount()-1, c.mod.NamedFunction("runtime.chanRecv"))
recvOp.SetOperand(2, llvm.Undef(c.i8ptrType))
bitcast.EraseFromParentAsInstruction()
// Yield to scheduler.
c.builder.SetInsertPointBefore(llvm.NextInstruction(recvOp))
continuePoint := c.builder.CreateCall(coroSuspendFunc, []llvm.Value{
llvm.ConstNull(c.ctx.TokenType()),
llvm.ConstInt(c.ctx.Int1Type(), 0, false),
}, "")
sw := c.builder.CreateSwitch(continuePoint, frame.suspendBlock, 2)
wakeup := c.splitBasicBlock(sw, llvm.NextBasicBlock(c.builder.GetInsertBlock()), "task.received")
// NOTE(review): this insert point looks redundant -- no instruction is
// built before the insert point is set again below. Confirm before removing.
c.builder.SetInsertPointAtEnd(recvOp.InstructionParent())
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), wakeup)
sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), frame.cleanupBlock)
// The value to receive is stored in taskState.data:
// runtime.chanRecv(coroutine, ch)
// promise := coroutine.promise()
// valueReceived := *(*valueType)(&promise.data)
// ok := promise.commaOk
c.builder.SetInsertPointBefore(wakeup.FirstInstruction())
promiseType := c.mod.GetTypeByName("runtime.taskState")
promiseRaw := c.builder.CreateCall(coroPromiseFunc, []llvm.Value{
frame.taskHandle,
llvm.ConstInt(c.ctx.Int32Type(), uint64(c.targetData.PrefTypeAlignment(promiseType)), false),
llvm.ConstInt(c.ctx.Int1Type(), 0, false),
}, "task.promise.raw")
promise := c.builder.CreateBitCast(promiseRaw, llvm.PointerType(promiseType, 0), "task.promise")
// Field index 2 of runtime.taskState is used as the data field.
dataPtr := c.builder.CreateGEP(promise, []llvm.Value{
llvm.ConstInt(c.ctx.Int32Type(), 0, false),
llvm.ConstInt(c.ctx.Int32Type(), 2, false),
}, "task.promise.data")
valueAlloca.ReplaceAllUsesWith(c.builder.CreateBitCast(dataPtr, valueAlloca.Type(), ""))
valueAlloca.EraseFromParentAsInstruction()
// Field index 1 of runtime.taskState is used as the comma-ok field.
commaOkPtr := c.builder.CreateGEP(promise, []llvm.Value{
llvm.ConstInt(c.ctx.Int32Type(), 0, false),
llvm.ConstInt(c.ctx.Int32Type(), 1, false),
}, "task.promise.comma-ok")
commaOk.ReplaceAllUsesWith(commaOkPtr)
// ReplaceAllUsesWith above also rewrote operand 3 of recvOp itself; reset
// it to undef.
recvOp.SetOperand(3, llvm.Undef(commaOk.Type()))
}
return true, c.lowerMakeGoroutineCalls()
}
// lowerMakeGoroutineCalls lowers runtime.makeGoroutine calls to regular call
// instructions. This is done after the regular goroutine transformations. The
// started goroutines are either non-blocking (in which case they can be called
// directly) or blocking, in which case they will ask the scheduler themselves
// to be rescheduled.
//
// An error is returned when a runtime.makeGoroutine call does not have the
// expected shape: exactly one bitcast use, which in turn has exactly one call
// use.
func (c *Compiler) lowerMakeGoroutineCalls() error {
	// The following Go code:
	//     go startedGoroutine()
	//
	// Is translated to the following during IR construction, to preserve the
	// fact that this function should be called as a new goroutine.
	//     %0 = call i8* @runtime.makeGoroutine(i8* bitcast (void (i8*, i8*)* @main.startedGoroutine to i8*), i8* undef, i8* null)
	//     %1 = bitcast i8* %0 to void (i8*, i8*)*
	//     call void %1(i8* undef, i8* undef)
	//
	// This function rewrites it to a direct call:
	//     call void @main.startedGoroutine(i8* undef, i8* null)
	for _, makeCall := range getUses(c.mod.NamedFunction("runtime.makeGoroutine")) {
		// The first operand is a bitcast of the function being started.
		target := makeCall.Operand(0).Operand(0)

		castUses := getUses(makeCall)
		if len(castUses) != 1 || castUses[0].IsABitCastInst().IsNil() {
			return errors.New("expected exactly 1 bitcast use of runtime.makeGoroutine")
		}
		castBack := castUses[0]

		callUses := getUses(castBack)
		if len(callUses) != 1 || callUses[0].IsACallInst().IsNil() {
			return errors.New("expected exactly 1 call use of runtime.makeGoroutine bitcast")
		}
		indirectCall := callUses[0]

		// Copy the call arguments: every operand except the last one, which
		// is the called value.
		argCount := indirectCall.OperandsCount() - 1
		args := make([]llvm.Value, argCount)
		for i := range args {
			args[i] = indirectCall.Operand(i)
		}
		// The final argument is the parent coroutine handle (must be nil).
		args[argCount-1] = llvm.ConstPointerNull(c.i8ptrType)

		// Emit the direct call, then drop the three instructions it replaces,
		// users first.
		c.builder.SetInsertPointBefore(indirectCall)
		c.builder.CreateCall(target, args, "")
		indirectCall.EraseFromParentAsInstruction()
		castBack.EraseFromParentAsInstruction()
		makeCall.EraseFromParentAsInstruction()
	}
	return nil
}