Scheduler: refactor scheduling_algo #3983

Open · wants to merge 1 commit into base: master

92 changes: 51 additions & 41 deletions internal/scheduler/scheduling/scheduling_algo.go
@@ -216,36 +216,63 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Con
func(queue *api.Queue) string { return queue.Name },
func(queue *api.Queue) *api.Queue { return queue })

- allKnownPools := map[string]bool{}
+ nodeFactory := internaltypes.NewNodeFactory(
+ l.schedulingConfig.IndexedTaints,
+ l.schedulingConfig.IndexedNodeLabels,
+ l.resourceListFactory,
+ )
+
+ nodes := []*internaltypes.Node{}
for _, executor := range executors {
for _, node := range executor.Nodes {
- allKnownPools[node.GetPool()] = true
+ if executor.Id != node.Executor {
+ ctx.Errorf("Executor name mismatch: %q != %q", node.Executor, executor.Id)
+ continue
+ }
+ itNode, err := nodeFactory.FromSchedulerObjectsNode(node)
+ if err != nil {
+ ctx.Errorf("Invalid node %s: %v", node.Name, err)
+ continue
+ }
+ nodes = append(nodes, itNode)
}
}

- jobSchedulingInfo, err := calculateJobSchedulingInfo(ctx, executors, queueByName, txn.GetAll(), maps.Keys(allKnownPools))
+ allKnownPools := map[string]bool{}
+ for _, node := range nodes {
+ allKnownPools[node.GetPool()] = true
+ }
+
+ jobSchedulingInfo, err := calculateJobSchedulingInfo(ctx,
+ armadamaps.FromSlice(executors,
+ func(ex *schedulerobjects.Executor) string { return ex.Id },
+ func(_ *schedulerobjects.Executor) bool { return true }),
+ queueByName,
+ txn.GetAll(),
+ maps.Keys(allKnownPools))
if err != nil {
return nil, err
}

// Filter out any executor that isn't acknowledging jobs in a timely fashion
// Note that we do this after aggregating allocation across clusters for fair share.
- healthyExecutors := l.filterLaggingExecutors(ctx, executors, jobSchedulingInfo.jobsByExecutorId)
- nodes := []*schedulerobjects.Node{}
- for _, executor := range healthyExecutors {
- nodes = append(nodes, executor.Nodes...)
- }
+ laggingExecutors := l.getLaggingExecutors(ctx, executors, jobSchedulingInfo.jobsByExecutorId)

pools := make(map[string]*poolSchedulingInfo, len(allKnownPools))
for _, pool := range maps.Keys(allKnownPools) {
- nodeDb, err := l.constructNodeDb(jobSchedulingInfo.jobsByPool[pool], armadaslices.Filter(nodes, func(node *schedulerobjects.Node) bool { return node.Pool == pool }))
+ nodeDb, err := l.constructNodeDb(jobSchedulingInfo.jobsByPool[pool], armadaslices.Filter(nodes,
+ func(node *internaltypes.Node) bool {
+ return node.GetPool() == pool && !laggingExecutors[node.GetExecutor()]
+ }))
if err != nil {
return nil, err
}

+ totalPoolResources := l.getCapacityForPool(pool, executors)
+
schedulingContext, err := l.constructSchedulingContext(
pool,
- l.getCapacityForPool(pool, executors),
+ totalPoolResources,
jobSchedulingInfo.demandByPoolByQueue[pool],
jobSchedulingInfo.allocatedByPoolAndQueueAndPriorityClass[pool],
queueByName)
@@ -257,7 +284,7 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Con
name: pool,
nodeDb: nodeDb,
schedulingContext: schedulingContext,
- totalCapacity: l.getCapacityForPool(pool, executors),
+ totalCapacity: totalPoolResources,
}
}

@@ -294,14 +321,9 @@ type jobSchedulingInfo struct {
allocatedByPoolAndQueueAndPriorityClass map[string]map[string]schedulerobjects.QuantityByTAndResourceType[string]
}

- func calculateJobSchedulingInfo(ctx *armadacontext.Context, executors []*schedulerobjects.Executor,
+ func calculateJobSchedulingInfo(ctx *armadacontext.Context, activeExecutorsSet map[string]bool,
queues map[string]*api.Queue, jobs []*jobdb.Job, allPools []string,
) (*jobSchedulingInfo, error) {
- activeExecutorsSet := map[string]bool{}
- for _, executor := range executors {
- activeExecutorsSet[executor.Id] = true
- }
-
jobsByExecutorId := make(map[string][]*jobdb.Job)
jobsByPool := make(map[string][]*jobdb.Job)
nodeIdByJobId := make(map[string]string)
@@ -416,13 +438,7 @@ func calculateJobSchedulingInfo(ctx *armadacontext.Context, executors []*schedul
}, nil
}

- func (l *FairSchedulingAlgo) constructNodeDb(jobs []*jobdb.Job, nodes []*schedulerobjects.Node) (*nodedb.NodeDb, error) {
- nodeFactory := internaltypes.NewNodeFactory(
- l.schedulingConfig.IndexedTaints,
- l.schedulingConfig.IndexedNodeLabels,
- l.resourceListFactory,
- )
-
+ func (l *FairSchedulingAlgo) constructNodeDb(jobs []*jobdb.Job, nodes []*internaltypes.Node) (*nodedb.NodeDb, error) {
nodeDb, err := nodedb.NewNodeDb(
l.schedulingConfig.PriorityClasses,
l.schedulingConfig.IndexedResources,
@@ -434,7 +450,7 @@ func (l *FairSchedulingAlgo) constructNodeDb(jobs []*jobdb.Job, nodes []*schedul
if err != nil {
return nil, err
}
- if err := l.populateNodeDb(nodeDb, nodeFactory, jobs, nodes); err != nil {
+ if err := l.populateNodeDb(nodeDb, jobs, nodes); err != nil {
return nil, err
}

@@ -550,12 +566,12 @@ func (l *FairSchedulingAlgo) schedulePool(
}

// populateNodeDb adds all the nodes and jobs associated with a particular pool to the nodeDb.
- func (l *FairSchedulingAlgo) populateNodeDb(nodeDb *nodedb.NodeDb, nodeFactory *internaltypes.NodeFactory, jobs []*jobdb.Job, nodes []*schedulerobjects.Node) error {
+ func (l *FairSchedulingAlgo) populateNodeDb(nodeDb *nodedb.NodeDb, jobs []*jobdb.Job, nodes []*internaltypes.Node) error {
txn := nodeDb.Txn(true)
defer txn.Abort()
nodesById := armadaslices.GroupByFuncUnique(
nodes,
- func(node *schedulerobjects.Node) string { return node.Id },
+ func(node *internaltypes.Node) string { return node.GetId() },
)
jobsByNodeId := make(map[string][]*jobdb.Job, len(nodes))
for _, job := range jobs {
@@ -574,12 +590,7 @@ func (l *FairSchedulingAlgo) populateNodeDb(nodeDb *nodedb.NodeDb, nodeFactory *
}

for _, node := range nodes {
- dbNode, err := nodeFactory.FromSchedulerObjectsNode(node)
- if err != nil {
- return err
- }
-
- if err := nodeDb.CreateAndInsertWithJobDbJobsWithTxn(txn, jobsByNodeId[node.Id], dbNode); err != nil {
+ if err := nodeDb.CreateAndInsertWithJobDbJobsWithTxn(txn, jobsByNodeId[node.GetId()], node); err != nil {
return err
}
}
@@ -602,16 +613,16 @@ func (l *FairSchedulingAlgo) filterStaleExecutors(ctx *armadacontext.Context, ex
return activeExecutors
}

- // filterLaggingExecutors returns all executors with <= l.schedulingConfig.MaxUnacknowledgedJobsPerExecutor unacknowledged jobs,
+ // getLaggingExecutors returns the names of all executors with > l.schedulingConfig.MaxUnacknowledgedJobsPerExecutor unacknowledged jobs,
// where unacknowledged means the executor has not echoed the job since it was scheduled.
//
// Used to rate-limit scheduling onto executors that can't keep up.
- func (l *FairSchedulingAlgo) filterLaggingExecutors(
+ func (l *FairSchedulingAlgo) getLaggingExecutors(
ctx *armadacontext.Context,
executors []*schedulerobjects.Executor,
leasedJobsByExecutor map[string][]*jobdb.Job,
- ) []*schedulerobjects.Executor {
- activeExecutors := make([]*schedulerobjects.Executor, 0, len(executors))
+ ) map[string]bool {
+ result := map[string]bool{}
for _, executor := range executors {
leasedJobs := leasedJobsByExecutor[executor.Id]
executorRuns := executor.AllRuns()
@@ -628,16 +639,15 @@ func (l *FairSchedulingAlgo) filterLaggingExecutors(
}
}
}
- if numUnacknowledgedJobs <= l.schedulingConfig.MaxUnacknowledgedJobsPerExecutor {
- activeExecutors = append(activeExecutors, executor)
- } else {
+ if numUnacknowledgedJobs > l.schedulingConfig.MaxUnacknowledgedJobsPerExecutor {
ctx.Warnf(
"%d unacknowledged jobs on executor %s exceeds limit of %d; executor will not be considered for scheduling",
numUnacknowledgedJobs, executor.Id, l.schedulingConfig.MaxUnacknowledgedJobsPerExecutor,
)
+ result[executor.Id] = true
}
}
- return activeExecutors
+ return result
}

// sortGroups sorts the given list of groups based on priorities defined in groupToPriority map.
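To make the key design change above concrete: lagging executors are now reported as a set (`map[string]bool`) by `getLaggingExecutors` instead of being filtered out of the executor slice, so their existing allocation still feeds fair-share accounting while their nodes are excluded from each pool's node DB inside `constructNodeDb`. Below is a minimal, self-contained sketch of that pattern; the `executor`/`node` structs and the threshold value are simplified stand-ins, not the real Armada types.

```go
package main

import "fmt"

// Simplified stand-ins for *schedulerobjects.Executor and *internaltypes.Node.
type executor struct {
	id                 string
	unacknowledgedJobs int
}

type node struct {
	executor string
	pool     string
}

// getLagging mirrors getLaggingExecutors: executors over the unacknowledged-jobs
// limit are reported in a set instead of being removed from the slice.
func getLagging(executors []executor, maxUnacknowledged int) map[string]bool {
	lagging := map[string]bool{}
	for _, ex := range executors {
		if ex.unacknowledgedJobs > maxUnacknowledged {
			lagging[ex.id] = true
		}
	}
	return lagging
}

func main() {
	executors := []executor{
		{id: "cluster-a", unacknowledgedJobs: 2},
		{id: "cluster-b", unacknowledgedJobs: 50},
	}
	nodes := []node{
		{executor: "cluster-a", pool: "cpu"},
		{executor: "cluster-b", pool: "cpu"},
	}

	lagging := getLagging(executors, 10)

	// Per-pool node selection, as in the constructNodeDb filter above: nodes on
	// lagging executors are skipped, but jobs already running on them were
	// counted earlier when allocation was aggregated for fair share.
	selected := []node{}
	for _, n := range nodes {
		if n.pool == "cpu" && !lagging[n.executor] {
			selected = append(selected, n)
		}
	}
	fmt.Println(selected) // only the cluster-a node remains
}
```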
16 changes: 15 additions & 1 deletion internal/scheduler/scheduling/scheduling_algo_test.go
@@ -15,6 +15,7 @@ import (
"github.com/armadaproject/armada/internal/common/armadacontext"
armadaslices "github.com/armadaproject/armada/internal/common/slices"
"github.com/armadaproject/armada/internal/scheduler/configuration"
"github.com/armadaproject/armada/internal/scheduler/internaltypes"
"github.com/armadaproject/armada/internal/scheduler/jobdb"
schedulermocks "github.com/armadaproject/armada/internal/scheduler/mocks"
"github.com/armadaproject/armada/internal/scheduler/nodedb"
@@ -628,6 +629,12 @@ func BenchmarkNodeDbConstruction(b *testing.B) {
require.NoError(b, err)
b.StartTimer()

+ nodeFactory := internaltypes.NewNodeFactory(
+ schedulingConfig.IndexedTaints,
+ schedulingConfig.IndexedNodeLabels,
+ testfixtures.TestResourceListFactory,
+ )
+
nodeDb, err := nodedb.NewNodeDb(
schedulingConfig.PriorityClasses,
schedulingConfig.IndexedResources,
@@ -637,7 +644,14 @@
testfixtures.TestResourceListFactory,
)
require.NoError(b, err)
- err = algo.populateNodeDb(nodeDb, testfixtures.TestNodeFactory, jobs, nodes)
+ dbNodes := []*internaltypes.Node{}
+ for _, node := range nodes {
+ dbNode, err := nodeFactory.FromSchedulerObjectsNode(node)
+ require.NoError(b, err)
+ dbNodes = append(dbNodes, dbNode)
+ }
+
+ err = algo.populateNodeDb(nodeDb, jobs, dbNodes)
require.NoError(b, err)
}
})
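The benchmark change mirrors the new caller-side contract: `constructNodeDb` and `populateNodeDb` now take already-converted `*internaltypes.Node` values, so conversion happens exactly once via a `NodeFactory` and invalid nodes are dealt with before anything touches the node DB (the production code logs and skips them with `ctx.Errorf`/`continue`, the benchmark asserts with `require.NoError`). A hedged sketch of that contract from a caller's point of view; `buildNodeDbNodes` and `rawNodes` are hypothetical names, and the armada `internaltypes`/`schedulerobjects` imports are assumed to be available:

```go
// Hypothetical helper illustrating the caller-side contract; the real callers
// are newFairSchedulingAlgoContext and the benchmark above.
func buildNodeDbNodes(
	nodeFactory *internaltypes.NodeFactory,
	rawNodes []*schedulerobjects.Node,
) ([]*internaltypes.Node, error) {
	// Convert once, up front; reject invalid nodes before they reach the node DB.
	dbNodes := make([]*internaltypes.Node, 0, len(rawNodes))
	for _, rawNode := range rawNodes {
		dbNode, err := nodeFactory.FromSchedulerObjectsNode(rawNode)
		if err != nil {
			return nil, err
		}
		dbNodes = append(dbNodes, dbNode)
	}
	// populateNodeDb no longer converts internally; it only indexes and inserts.
	return dbNodes, nil
}
```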