Commit 008bd40

[query] Configure Optimiser via Flags, not HailContext

1 parent 7b2b24e · commit 008bd40
32 files changed: +255 −257 lines
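
At a glance: the optimiser is no longer configured when the `HailContext` singleton is constructed; it is driven by `HailFeatureFlags` instead. A minimal before/after sketch using identifiers from this commit (the `backend` value is assumed, and import paths are inferred from the touched files):

```scala
import is.hail.{HailContext, HailFeatureFlags}
import is.hail.backend.Backend
import is.hail.expr.ir.Optimize

def init(backend: Backend): HailContext =
  // Before this commit the optimiser rode along on the context:
  //   HailContext(backend, branchingFactor = 50, optimizerIterations = 3)
  // Now the context carries no optimiser state:
  HailContext(backend, branchingFactor = 50)

// The optimiser is toggled through feature flags, seeded from the environment:
// HAIL_QUERY_OPTIMIZE defaults to "1"; HAIL_OPTIMIZER_ITERATIONS is unset (null).
val flags = HailFeatureFlags.fromEnv()
val noOpt = flags - Optimize.Flags.Optimize // a new flag set with the optimiser disabled
```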

hail/hail/src/is/hail/HailContext.scala

Lines changed: 4 additions & 12 deletions
```diff
@@ -92,26 +92,19 @@ object HailContext {
     }
   }
 
-  def getOrCreate(backend: Backend, branchingFactor: Int = 50, optimizerIterations: Int = 3)
-    : HailContext = {
+  def getOrCreate(backend: Backend, branchingFactor: Int = 50): HailContext = {
     if (theContext == null)
-      return HailContext(backend, branchingFactor, optimizerIterations)
+      return HailContext(backend, branchingFactor)
 
     if (theContext.branchingFactor != branchingFactor)
       warn(
         s"Requested branchingFactor $branchingFactor, but already initialized to ${theContext.branchingFactor}. Ignoring requested setting."
       )
 
-    if (theContext.optimizerIterations != optimizerIterations)
-      warn(
-        s"Requested optimizerIterations $optimizerIterations, but already initialized to ${theContext.optimizerIterations}. Ignoring requested setting."
-      )
-
     theContext
   }
 
-  def apply(backend: Backend, branchingFactor: Int = 50, optimizerIterations: Int = 3)
-    : HailContext = synchronized {
+  def apply(backend: Backend, branchingFactor: Int = 50): HailContext = synchronized {
     require(theContext == null)
     checkJavaVersion()
 
@@ -129,7 +122,7 @@ object HailContext {
       )
     }
 
-    theContext = new HailContext(backend, branchingFactor, optimizerIterations)
+    theContext = new HailContext(backend, branchingFactor)
 
     info(s"Running Hail version ${theContext.version}")
 
@@ -173,7 +166,6 @@ object HailContext {
 class HailContext private (
   var backend: Backend,
   val branchingFactor: Int,
-  val optimizerIterations: Int,
 ) {
   def stop(): Unit = HailContext.stop()
 
```

hail/hail/src/is/hail/HailFeatureFlags.scala

Lines changed: 13 additions & 1 deletion
```diff
@@ -2,6 +2,7 @@ package is.hail
 
 import is.hail.backend.ExecutionCache
 import is.hail.backend.spark.SparkBackend
+import is.hail.expr.ir.Optimize
 import is.hail.io.fs.RequesterPaysConfig
 import is.hail.types.encoded.EType
 import is.hail.utils._
@@ -45,6 +46,8 @@ object HailFeatureFlags {
       SparkBackend.Flags.MaxStageParallelism,
       "HAIL_SPARK_MAX_STAGE_PARALLELISM" -> Integer.MAX_VALUE.toString,
     ),
+    (Optimize.Flags.Optimize, "HAIL_QUERY_OPTIMIZE" -> "1"),
+    (Optimize.Flags.MaxOptimizerIterations, "HAIL_OPTIMIZER_ITERATIONS" -> null),
   )
 
   def fromEnv(m: Map[String, String] = sys.env): HailFeatureFlags =
@@ -58,7 +61,7 @@
 }
 
 class HailFeatureFlags private (
-  val flags: mutable.Map[String, String]
+  private[this] val flags: mutable.Map[String, String]
 ) extends Serializable {
   val available: java.util.ArrayList[String] =
     new java.util.ArrayList[String](java.util.Arrays.asList[String](flags.keys.toSeq: _*))
@@ -71,11 +74,20 @@ class HailFeatureFlags private (
   def +(feature: (String, String)): HailFeatureFlags =
     new HailFeatureFlags(flags + (feature._1 -> feature._2))
 
+  def define(feature: String): HailFeatureFlags =
+    new HailFeatureFlags(flags + (feature -> "1"))
+
+  def -(feature: String): HailFeatureFlags =
+    new HailFeatureFlags(flags - feature)
+
   def get(flag: String): String = flags(flag)
 
   def lookup(flag: String): Option[String] =
     Option(flags(flag)).filter(_.nonEmpty)
 
+  def isDefined(flag: String): Boolean =
+    lookup(flag).isDefined
+
   def exists(flag: String): Boolean = flags.contains(flag)
 
   def toJSONEnv: JArray =
```

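The new `define`/`-`/`isDefined` methods give call sites an immutable-style way to flip boolean flags. A small usage sketch, assuming only the API visible in this diff:

```scala
import is.hail.HailFeatureFlags
import is.hail.expr.ir.Optimize

// Seed the flag map from the environment; per the defaults added above,
// Optimize.Flags.Optimize maps to HAIL_QUERY_OPTIMIZE and defaults to "1".
val flags = HailFeatureFlags.fromEnv()

flags.isDefined(Optimize.Flags.Optimize)            // true: lookup(_) is Some("1")
flags.lookup(Optimize.Flags.MaxOptimizerIterations) // None unless HAIL_OPTIMIZER_ITERATIONS is set

val off = flags - Optimize.Flags.Optimize // removes the key entirely
off.exists(Optimize.Flags.Optimize)       // false: no longer in the map

val on = off.define(Optimize.Flags.Optimize) // re-adds the flag with value "1"
on.isDefined(Optimize.Flags.Optimize)        // true again
```

Note that `lookup` (and therefore `isDefined`) indexes the underlying map directly, so it is only safe on flags present in the map; `exists` is the membership test.
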
hail/hail/src/is/hail/backend/local/LocalBackend.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -106,7 +106,7 @@ object LocalBackend extends Backend {
     print: Option[PrintWriter] = None,
   ): Either[Unit, (PTuple, Long)] =
     ctx.time {
-      val ir = LoweringPipeline.darrayLowerer(true)(DArrayLowering.All)(ctx, ir0).asInstanceOf[IR]
+      val ir = LoweringPipeline.darrayLowerer(DArrayLowering.All)(ctx, ir0).asInstanceOf[IR]
 
       if (!Compilable(ir))
         throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${Pretty(ctx, ir)}")
```

hail/hail/src/is/hail/backend/service/ServiceBackend.scala

Lines changed: 1 addition & 3 deletions
```diff
@@ -293,7 +293,7 @@ class ServiceBackend(
 
   private[this] def _jvmLowerAndExecute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] =
     ctx.time {
-      val x = LoweringPipeline.darrayLowerer(true)(DArrayLowering.All)(ctx, ir).asInstanceOf[IR]
+      val x = LoweringPipeline.darrayLowerer(DArrayLowering.All)(ctx, ir).asInstanceOf[IR]
 
       x.typ match {
         case TVoid =>
@@ -303,7 +303,6 @@
             FastSeq[TypeInfo[_]](classInfo[Region]),
             UnitInfo,
             x,
-            optimize = true,
           )
 
           Left(ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r)(r)))
@@ -315,7 +314,6 @@
             FastSeq(classInfo[Region]),
             LongInfo,
             MakeTuple.ordered(FastSeq(x)),
-            optimize = true,
           )
 
           Right((pt, ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r)(r))))
```

hail/hail/src/is/hail/backend/spark/SparkBackend.scala

Lines changed: 4 additions & 6 deletions
```diff
@@ -386,20 +386,18 @@ class SparkBackend(val sc: SparkContext) extends Backend {
   def jvmLowerAndExecute(
     ctx: ExecuteContext,
     ir0: IR,
-    optimize: Boolean,
     lowerTable: Boolean,
     lowerBM: Boolean,
     print: Option[PrintWriter] = None,
   ): Any =
-    _jvmLowerAndExecute(ctx, ir0, optimize, lowerTable, lowerBM, print) match {
+    _jvmLowerAndExecute(ctx, ir0, lowerTable, lowerBM, print) match {
       case Left(x) => x
       case Right((pt, off)) => SafeRow(pt, off).get(0)
     }
 
   private[this] def _jvmLowerAndExecute(
     ctx: ExecuteContext,
     ir0: IR,
-    optimize: Boolean,
     lowerTable: Boolean,
     lowerBM: Boolean,
     print: Option[PrintWriter] = None,
@@ -411,7 +409,7 @@ class SparkBackend(val sc: SparkContext) extends Backend {
       case (false, true) => DArrayLowering.BMOnly
       case (false, false) => throw new LowererUnsupportedOperation("no lowering enabled")
     }
-    val ir = LoweringPipeline.darrayLowerer(optimize)(typesToLower)(ctx, ir0).asInstanceOf[IR]
+    val ir = LoweringPipeline.darrayLowerer(typesToLower)(ctx, ir0).asInstanceOf[IR]
 
     if (!Compilable(ir))
       throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${Pretty(ctx, ir)}")
@@ -451,11 +449,11 @@ class SparkBackend(val sc: SparkContext) extends Backend {
     try {
       val lowerTable = ctx.flags.get("lower") != null
       val lowerBM = ctx.flags.get("lower_bm") != null
-      _jvmLowerAndExecute(ctx, ir, optimize = true, lowerTable, lowerBM)
+      _jvmLowerAndExecute(ctx, ir, lowerTable, lowerBM)
     } catch {
       case e: LowererUnsupportedOperation if ctx.flags.get("lower_only") != null => throw e
       case _: LowererUnsupportedOperation =>
-        CompileAndEvaluate._apply(ctx, ir, optimize = true)
+        CompileAndEvaluate._apply(ctx, ir)
     }
   }
 
```

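With `optimize: Boolean` gone from every signature above, the on/off decision presumably happens once, inside the lowering pipeline, by consulting the context's flags. The commit does not show `LoweringPipeline`'s internals, so the following is a hypothetical sketch of that gating; `runOptimizerPasses` is an assumed stand-in, not a real Hail function:

```scala
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.{BaseIR, Optimize}

// Hypothetical stand-in for whatever optimiser entry point the pipeline calls.
def runOptimizerPasses(ctx: ExecuteContext, ir: BaseIR): BaseIR = ir

// Instead of threading `optimize: Boolean` through every caller, consult the
// flag at the single point where optimisation can happen.
def maybeOptimize(ctx: ExecuteContext, ir: BaseIR): BaseIR =
  if (ctx.flags.isDefined(Optimize.Flags.Optimize)) runOptimizerPasses(ctx, ir)
  else ir
```
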
hail/hail/src/is/hail/expr/ir/BlockMatrixIR.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -1044,7 +1044,7 @@ case class ValueToBlockMatrix(
   override protected[ir] def execute(ctx: ExecuteContext): BlockMatrix = {
     val IndexedSeq(nRows, nCols) = shape
     BlockMatrixIR.checkFitsIntoArray(nRows, nCols)
-    CompileAndEvaluate[Any](ctx, child, true) match {
+    CompileAndEvaluate[Any](ctx, child) match {
       case scalar: Double =>
         assert(nRows == 1 && nCols == 1)
         BlockMatrix.fill(nRows, nCols, scalar, blockSize)
```

hail/hail/src/is/hail/expr/ir/Compile.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -101,7 +101,7 @@ object compile {
         body,
         BindingEnv(Env.fromSeq(params.zipWithIndex.map { case ((n, t), i) => n -> In(i, t) })),
       )
-    ir = LoweringPipeline.compileLowerer(optimize)(ctx, ir).asInstanceOf[IR].noSharing(ctx)
+    ir = LoweringPipeline.compileLowerer(ctx, ir).asInstanceOf[IR].noSharing(ctx)
     TypeCheck(ctx, ir)
 
     val fb = EmitFunctionBuilder[F](
@@ -207,7 +207,7 @@ object CompileIterator {
 
     val outerRegion = outerRegionField
 
-    val ir = LoweringPipeline.compileLowerer(true)(ctx, body).asInstanceOf[IR].noSharing(ctx)
+    val ir = LoweringPipeline.compileLowerer(ctx, body).asInstanceOf[IR].noSharing(ctx)
     TypeCheck(ctx, ir)
 
     var elementAddress: Settable[Long] = null
```

hail/hail/src/is/hail/expr/ir/CompileAndEvaluate.scala

Lines changed: 6 additions & 13 deletions
```diff
@@ -14,20 +14,19 @@ import is.hail.utils.FastSeq
 import org.apache.spark.sql.Row
 
 object CompileAndEvaluate {
-  def apply[T](ctx: ExecuteContext, ir0: IR, optimize: Boolean = true): T = {
+  def apply[T](ctx: ExecuteContext, ir0: IR): T =
     ctx.time {
-      _apply(ctx, ir0, optimize) match {
+      _apply(ctx, ir0) match {
         case Left(()) => ().asInstanceOf[T]
         case Right((t, off)) => SafeRow(t, off).getAs[T](0)
       }
     }
-  }
 
-  def evalToIR(ctx: ExecuteContext, ir0: IR, optimize: Boolean = true): IR = {
+  def evalToIR(ctx: ExecuteContext, ir0: IR): IR = {
     if (IsConstant(ir0))
       return ir0
 
-    _apply(ctx, ir0, optimize) match {
+    _apply(ctx, ir0) match {
       case Left(_) => Begin(FastSeq())
       case Right((pt, addr)) =>
         ir0.typ match {
@@ -39,13 +38,9 @@
         }
     }
 
-  def _apply(
-    ctx: ExecuteContext,
-    ir0: IR,
-    optimize: Boolean = true,
-  ): Either[Unit, (PTuple, Long)] =
+  def _apply(ctx: ExecuteContext, ir0: IR): Either[Unit, (PTuple, Long)] =
     ctx.time {
-      val ir = LoweringPipeline.relationalLowerer(optimize)(ctx, ir0).asInstanceOf[IR]
+      val ir = LoweringPipeline.relationalLowerer(ctx, ir0).asInstanceOf[IR]
 
       ir.typ match {
         case TVoid =>
@@ -56,7 +51,6 @@
             UnitInfo,
             ir,
             print = None,
-            optimize = optimize,
           )
 
           val unit: Unit = ctx.scopedExecution { (hcl, fs, htc, r) =>
@@ -75,7 +69,6 @@
             LongInfo,
             MakeTuple.ordered(FastSeq(ir)),
             print = None,
-            optimize = optimize,
           )
 
           val res = ctx.scopedExecution { (hcl, fs, htc, r) =>
```

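Call sites become uniformly simpler: what used to be a per-call boolean is now ambient state on the execution context. A sketch, assuming `ctx` and `ir` are in scope:

```scala
// was: CompileAndEvaluate[Long](ctx, ir, optimize = true)
val value = CompileAndEvaluate[Long](ctx, ir)

// was: CompileAndEvaluate.evalToIR(ctx, ir, optimize = true)
val folded = CompileAndEvaluate.evalToIR(ctx, ir)
```
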
hail/hail/src/is/hail/expr/ir/Interpret.scala

Lines changed: 7 additions & 11 deletions
```diff
@@ -24,25 +24,22 @@ import org.apache.spark.sql.Row
 object Interpret {
   type Agg = (IndexedSeq[Row], TStruct)
 
-  def apply(tir: TableIR, ctx: ExecuteContext): TableValue =
-    apply(tir, ctx, optimize = true)
-
-  def apply(tir: TableIR, ctx: ExecuteContext, optimize: Boolean): TableValue = {
+  def apply(tir: TableIR, ctx: ExecuteContext): TableValue = {
     val lowered =
-      LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, tir).asInstanceOf[TableIR].noSharing(
+      LoweringPipeline.legacyRelationalLowerer(ctx, tir).asInstanceOf[TableIR].noSharing(
         ctx
       )
     ExecuteRelational(ctx, lowered).asTableValue(ctx)
   }
 
-  def apply(mir: MatrixIR, ctx: ExecuteContext, optimize: Boolean): TableValue = {
-    val lowered = LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, mir).asInstanceOf[TableIR]
+  def apply(mir: MatrixIR, ctx: ExecuteContext): TableValue = {
+    val lowered = LoweringPipeline.legacyRelationalLowerer(ctx, mir).asInstanceOf[TableIR]
     ExecuteRelational(ctx, lowered).asTableValue(ctx)
   }
 
-  def apply(bmir: BlockMatrixIR, ctx: ExecuteContext, optimize: Boolean): BlockMatrix = {
+  def apply(bmir: BlockMatrixIR, ctx: ExecuteContext): BlockMatrix = {
     val lowered =
-      LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, bmir).asInstanceOf[BlockMatrixIR]
+      LoweringPipeline.legacyRelationalLowerer(ctx, bmir).asInstanceOf[BlockMatrixIR]
     lowered.execute(ctx)
   }
 
@@ -54,13 +51,12 @@ object Interpret {
     ir0: IR,
     env: Env[(Any, Type)],
     args: IndexedSeq[(Any, Type)],
-    optimize: Boolean = true,
   ): T = {
     val bindings = env.m.view.map { case (k, (value, t)) =>
       k -> Literal.coerce(t, value)
     }.toFastSeq
     val lowered =
-      LoweringPipeline.relationalLowerer(optimize).apply(ctx, Let(bindings, ir0)).asInstanceOf[IR]
+      LoweringPipeline.relationalLowerer(ctx, Let(bindings, ir0)).asInstanceOf[IR]
     val result = run(ctx, lowered, Env.empty[Any], args, Memo.empty).asInstanceOf[T]
     result
   }
```

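The `Interpret` overloads lose their `optimize` parameters the same way, leaving one entry point per IR kind. A sketch, assuming `tir`, `mir`, `bmir`, and `ctx` are in scope:

```scala
val tv = Interpret(tir, ctx)  // was Interpret(tir, ctx, optimize = true)
val mv = Interpret(mir, ctx)  // MatrixIR overload, same simplification
val bm = Interpret(bmir, ctx) // BlockMatrixIR overload, returns a BlockMatrix
```
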
hail/hail/src/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -28,7 +28,7 @@ object LowerOrInterpretNonCompilable {
       val fullyLowered = LowerToDistributedArrayPass(DArrayLowering.All).transform(ctx, value)
         .asInstanceOf[IR]
       log.info(s"compiling and evaluating result: ${value.getClass.getSimpleName}")
-      CompileAndEvaluate.evalToIR(ctx, fullyLowered, true)
+      CompileAndEvaluate.evalToIR(ctx, fullyLowered)
     }
     log.info(s"took ${formatTime(System.nanoTime() - preTime)}")
     assert(result.typ == value.typ)
```
