Simulate impact of shard movement using shard-level write load #131406

Open · wants to merge 9 commits into main
ClusterInfoSimulator.java
@@ -11,6 +11,7 @@

import org.elasticsearch.cluster.ClusterInfo.NodeAndShard;
import org.elasticsearch.cluster.routing.ShardRouting;
+ import org.elasticsearch.cluster.routing.WriteLoadPerShardSimulator;
import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
import org.elasticsearch.common.util.CopyOnFirstWriteMap;
import org.elasticsearch.index.shard.ShardId;
@@ -34,7 +35,7 @@ public class ClusterInfoSimulator {
private final Map<ShardId, Long> shardDataSetSizes;
private final Map<NodeAndShard, String> dataPath;
private final Map<String, EstimatedHeapUsage> estimatedHeapUsages;
- private final Map<String, NodeUsageStatsForThreadPools> nodeThreadPoolUsageStats;
+ private final WriteLoadPerShardSimulator writeLoadPerShardSimulator;

public ClusterInfoSimulator(RoutingAllocation allocation) {
this.allocation = allocation;
@@ -44,7 +45,7 @@ public ClusterInfoSimulator(RoutingAllocation allocation) {
this.shardDataSetSizes = Map.copyOf(allocation.clusterInfo().shardDataSetSizes);
this.dataPath = Map.copyOf(allocation.clusterInfo().dataPath);
this.estimatedHeapUsages = allocation.clusterInfo().getEstimatedHeapUsages();
- this.nodeThreadPoolUsageStats = allocation.clusterInfo().getNodeUsageStatsForThreadPools();
+ this.writeLoadPerShardSimulator = new WriteLoadPerShardSimulator(allocation);
}

/**
@@ -115,6 +116,7 @@ public void simulateShardStarted(ShardRouting shard) {
shardSizes.put(shardIdentifierFromRouting(shard), project.getIndexSafe(shard.index()).ignoreDiskWatermarks() ? 0 : size);
}
}
+ writeLoadPerShardSimulator.simulateShardStarted(shard);
}

private void modifyDiskUsage(String nodeId, long freeDelta) {
@@ -159,7 +161,7 @@ public ClusterInfo getClusterInfo() {
dataPath,
Map.of(),
estimatedHeapUsages,
- nodeThreadPoolUsageStats
+ writeLoadPerShardSimulator.nodeUsageStatsForThreadPools()
);
}
}
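
The net effect of this diff: ClusterInfoSimulator delegates per-shard write-load bookkeeping to the new simulator and surfaces its adjusted stats through getClusterInfo(). A minimal caller-side sketch, assuming an existing RoutingAllocation `allocation` and a started ShardRouting `startedShard` (both names are illustrative, not from the PR):

// Hypothetical usage of the updated ClusterInfoSimulator:
ClusterInfoSimulator simulator = new ClusterInfoSimulator(allocation);
simulator.simulateShardStarted(startedShard); // also feeds WriteLoadPerShardSimulator
ClusterInfo simulated = simulator.getClusterInfo(); // write-pool stats reflect simulated moves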
WriteLoadPerShardSimulator.java (new file)
@@ -0,0 +1,151 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.cluster.routing;

import com.carrotsearch.hppc.ObjectFloatHashMap;
import com.carrotsearch.hppc.ObjectFloatMap;

import org.elasticsearch.cluster.NodeUsageStatsForThreadPools;
import org.elasticsearch.cluster.metadata.IndexAbstraction;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

public class WriteLoadPerShardSimulator {

private final ObjectFloatMap<String> writeLoadDeltas;
[Review — Contributor] nit: simulatedNodesLoad?

[Author] 👍 I changed it to simulatedWriteLoadDeltas; we only store the delta from the reported/original write load here. The idea is that if no delta is present, we can just return the original NodeUsageStatsForThreadPools instance.
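
To make the delta bookkeeping concrete (numbers are illustrative): relocating a shard with an estimated write load of 1.2 from node_0 to node_1 records {node_0: -1.2, node_1: +1.2} in this map; a node with no entry, say node_2, gets its original NodeUsageStatsForThreadPools instance back unchanged.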

private final RoutingAllocation routingAllocation;
private final ObjectFloatMap<ShardId> writeLoadsPerShard;

public WriteLoadPerShardSimulator(RoutingAllocation routingAllocation) {
this.routingAllocation = routingAllocation;
this.writeLoadDeltas = new ObjectFloatHashMap<>();
writeLoadsPerShard = estimateWriteLoadsPerShard(routingAllocation);
}

public void simulateShardStarted(ShardRouting shardRouting) {
final float writeLoadForShard = writeLoadsPerShard.get(shardRouting.shardId());
if (writeLoadForShard > 0.0) {
if (shardRouting.relocatingNodeId() != null) {
// relocating
writeLoadDeltas.addTo(shardRouting.relocatingNodeId(), -1 * writeLoadForShard);
writeLoadDeltas.addTo(shardRouting.currentNodeId(), writeLoadForShard);
} else {
// not sure how this would come about, perhaps when allocating a replica after a delay?
writeLoadDeltas.addTo(shardRouting.currentNodeId(), writeLoadForShard);
}
}
}

public Map<String, NodeUsageStatsForThreadPools> nodeUsageStatsForThreadPools() {
return routingAllocation.clusterInfo()
.getNodeUsageStatsForThreadPools()
.entrySet()
.stream()
.collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, e -> {
if (writeLoadDeltas.containsKey(e.getKey())) {
return new NodeUsageStatsForThreadPools(
e.getKey(),
Maps.copyMapWithAddedOrReplacedEntry(
e.getValue().threadPoolUsageStatsMap(),
"write",
replaceWritePoolStats(e.getValue(), writeLoadDeltas.get(e.getKey()))
)
);
}
return e.getValue();
}));
}

private NodeUsageStatsForThreadPools.ThreadPoolUsageStats replaceWritePoolStats(
NodeUsageStatsForThreadPools value,
float writeLoadDelta
) {
final NodeUsageStatsForThreadPools.ThreadPoolUsageStats writeThreadPoolStats = value.threadPoolUsageStatsMap()
.get(ThreadPool.Names.WRITE);
return new NodeUsageStatsForThreadPools.ThreadPoolUsageStats(
writeThreadPoolStats.totalThreadPoolThreads(),
writeThreadPoolStats.averageThreadPoolUtilization() + (writeLoadDelta / writeThreadPoolStats.totalThreadPoolThreads()),
writeThreadPoolStats.averageThreadPoolQueueLatencyMillis()
);
}
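
A worked example of the adjustment above (illustrative numbers): with 8 write threads at 0.50 average utilization and an accumulated delta of -1.2 (one shard's load moved away), the simulated utilization becomes 0.50 + (-1.2 / 8) = 0.35. The delta is measured in "threads' worth" of load, so dividing by the pool size converts it back to a utilization fraction.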

// Everything below this line can probably go once we are publishing shard-write-load estimates to the master

private static ObjectFloatMap<ShardId> estimateWriteLoadsPerShard(RoutingAllocation allocation) {
final Map<ShardId, Average> writeLoadPerShard = new HashMap<>();
final Set<String> writeIndexNames = getWriteIndexNames(allocation);
final Map<String, NodeUsageStatsForThreadPools> nodeUsageStatsForThreadPools = allocation.clusterInfo()
.getNodeUsageStatsForThreadPools();
for (final Map.Entry<String, NodeUsageStatsForThreadPools> usageStatsForThreadPoolsEntry : nodeUsageStatsForThreadPools
.entrySet()) {
final NodeUsageStatsForThreadPools value = usageStatsForThreadPoolsEntry.getValue();
final NodeUsageStatsForThreadPools.ThreadPoolUsageStats writeThreadPoolStats = value.threadPoolUsageStatsMap()
.get(ThreadPool.Names.WRITE);
if (writeThreadPoolStats == null) {
// No stats from this node yet
continue;
}
float writeUtilisation = writeThreadPoolStats.averageThreadPoolUtilization() * writeThreadPoolStats.totalThreadPoolThreads();

final String nodeId = usageStatsForThreadPoolsEntry.getKey();
final RoutingNode node = allocation.routingNodes().node(nodeId);
final Set<ShardId> writeShardsOnNode = new HashSet<>();
for (final ShardRouting shardRouting : node) {
if (shardRouting.role() != ShardRouting.Role.SEARCH_ONLY && writeIndexNames.contains(shardRouting.index().getName())) {
writeShardsOnNode.add(shardRouting.shardId());
}
}
writeShardsOnNode.forEach(
shardId -> writeLoadPerShard.computeIfAbsent(shardId, k -> new Average()).add(writeUtilisation / writeShardsOnNode.size())
);

[Review — Contributor] Do you equally divide the write load across all write shards on the node?

[Author] Yeah, this is just a stop-gap until we get actual shard loads, which should work as a drop-in replacement.

[Review — Contributor (@mhl-b, Jul 17, 2025)] OK, I was thinking maybe we should have some heuristic based on already-available data; otherwise the signal-to-noise ratio is poor. It's not uncommon to have hundreds of shards, and then the estimate has little to no impact on any single shard. For example, use a shard-size heuristic: the larger the shard, the more likely it is to carry write load. Say we increase a shard's weight linearly as its size approaches 15GB, then decrease it as the size approaches 30GB, since we would (most of the time) roll it over: weight = size/15GB if size < 15GB, else max(0, 1 - size/30GB).

[Author] We'll have actual shard write loads shortly. Hopefully we can avoid all this guessing entirely. #131496
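
A minimal sketch of the reviewer's proposed size-based weighting, assuming shard sizes in bytes and the 15GB/30GB thresholds from the comment; the method name is hypothetical and nothing like it is in this PR:

// Hypothetical weighting from the review discussion, not part of this change.
// Weight rises linearly up to ~15GB, then falls towards zero by ~30GB, on the
// assumption that shards past ~15GB are likely to be rolled over soon.
private static float sizeBasedWriteWeight(long shardSizeBytes) {
    final long fifteenGb = ByteSizeValue.ofGb(15).getBytes(); // org.elasticsearch.common.unit.ByteSizeValue
    final long thirtyGb = ByteSizeValue.ofGb(30).getBytes();
    if (shardSizeBytes < fifteenGb) {
        return (float) shardSizeBytes / fifteenGb;
    }
    return Math.max(0.0f, 1.0f - ((float) shardSizeBytes / thirtyGb));
}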
}
final ObjectFloatMap<ShardId> writeLoads = new ObjectFloatHashMap<>(writeLoadPerShard.size());
writeLoadPerShard.forEach((shardId, average) -> writeLoads.put(shardId, average.get()));
return writeLoads;
}
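
A worked example of the estimate (illustrative numbers): a node reporting 8 write threads at 0.6 average utilization carries 8 × 0.6 = 4.8 threads' worth of write load; with three write shards on the node, each is credited 4.8 / 3 = 1.6. If one of those shards also has a copy on another node where the same calculation yields 1.2, the Average above produces a final per-shard estimate of (1.6 + 1.2) / 2 = 1.4.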

private static Set<String> getWriteIndexNames(RoutingAllocation allocation) {
return allocation.metadata()
.projects()
.values()
.stream()
.map(ProjectMetadata::getIndicesLookup)
.flatMap(indicesLookup -> indicesLookup.values().stream())
.map(IndexAbstraction::getWriteIndex)
.filter(Objects::nonNull)
.map(Index::getName)
.collect(Collectors.toUnmodifiableSet());
}

private static final class Average {
int count;
float sum;

public void add(float value) {
count++;
sum += value;
}

public float get() {
return sum / count;
}
}
}
WriteLoadPerShardSimulatorTests.java (new file)
@@ -0,0 +1,178 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.cluster.routing;

import org.elasticsearch.action.support.replication.ClusterStateCreationUtils;
import org.elasticsearch.cluster.ClusterInfo;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.NodeUsageStatsForThreadPools;
import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matchers;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.StreamSupport;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.sameInstance;

public class WriteLoadPerShardSimulatorTests extends ESTestCase {

private static final RoutingChangesObserver NOOP = new RoutingChangesObserver() {
};

/**
* We should not adjust the values if there's no movement
*/
public void testNoShardMovement() {
final var originalNode0WriteLoadStats = randomUsageStats();
final var originalNode1WriteLoadStats = randomUsageStats();
final var allocation = createRoutingAllocation(originalNode0WriteLoadStats, originalNode1WriteLoadStats);

final var writeLoadPerShardSimulator = new WriteLoadPerShardSimulator(allocation);
final var calculatedNodeUsageStates = writeLoadPerShardSimulator.nodeUsageStatsForThreadPools();
assertThat(calculatedNodeUsageStates, Matchers.aMapWithSize(2));
assertThat(
calculatedNodeUsageStates.get("node_0").threadPoolUsageStatsMap().get("write"),
sameInstance(originalNode0WriteLoadStats)
);
assertThat(
calculatedNodeUsageStates.get("node_1").threadPoolUsageStatsMap().get("write"),
sameInstance(originalNode1WriteLoadStats)
);
}

public void testMovementOfAShardWillReduceThreadPoolUtilisation() {
final var originalNode0WriteLoadStats = randomUsageStats();
final var originalNode1WriteLoadStats = randomUsageStats();
final var allocation = createRoutingAllocation(originalNode0WriteLoadStats, originalNode1WriteLoadStats);
final var writeLoadPerShardSimulator = new WriteLoadPerShardSimulator(allocation);

// Relocate a random shard from node_0 to node_1
final var randomShard = randomFrom(StreamSupport.stream(allocation.routingNodes().node("node_0").spliterator(), false).toList());
final var moveShardTuple = allocation.routingNodes().relocateShard(randomShard, "node_1", randomNonNegativeLong(), "testing", NOOP);
writeLoadPerShardSimulator.simulateShardStarted(moveShardTuple.v2());

final var calculatedNodeUsageStates = writeLoadPerShardSimulator.nodeUsageStatsForThreadPools();
assertThat(calculatedNodeUsageStates, Matchers.aMapWithSize(2));

// Some node_0 utilization should have been moved to node_1
assertThat(
getAverageWritePoolUtilization(writeLoadPerShardSimulator, "node_0"),
lessThan(originalNode0WriteLoadStats.averageThreadPoolUtilization())
);
assertThat(
getAverageWritePoolUtilization(writeLoadPerShardSimulator, "node_1"),
greaterThan(originalNode1WriteLoadStats.averageThreadPoolUtilization())
);
}

public void testMovementFollowedByMovementBackWillNotChangeAnything() {
final var originalNode0WriteLoadStats = randomUsageStats();
final var originalNode1WriteLoadStats = randomUsageStats();
final var allocation = createRoutingAllocation(originalNode0WriteLoadStats, originalNode1WriteLoadStats);
final var writeLoadPerShardSimulator = new WriteLoadPerShardSimulator(allocation);

// Relocate a random shard from node_0 to node_1
final long expectedShardSize = randomNonNegativeLong();
final var randomShard = randomFrom(StreamSupport.stream(allocation.routingNodes().node("node_0").spliterator(), false).toList());
final var moveShardTuple = allocation.routingNodes().relocateShard(randomShard, "node_1", expectedShardSize, "testing", NOOP);
writeLoadPerShardSimulator.simulateShardStarted(moveShardTuple.v2());
final ShardRouting movedAndStartedShard = allocation.routingNodes().startShard(moveShardTuple.v2(), NOOP, expectedShardSize);

// Some node_0 utilization should have been moved to node_1
assertThat(
getAverageWritePoolUtilization(writeLoadPerShardSimulator, "node_0"),
lessThan(originalNode0WriteLoadStats.averageThreadPoolUtilization())
);
assertThat(
getAverageWritePoolUtilization(writeLoadPerShardSimulator, "node_1"),
greaterThan(originalNode1WriteLoadStats.averageThreadPoolUtilization())
);

// Then move it back
final var moveBackTuple = allocation.routingNodes()
.relocateShard(movedAndStartedShard, "node_0", expectedShardSize, "testing", NOOP);
writeLoadPerShardSimulator.simulateShardStarted(moveBackTuple.v2());

// The utilization numbers should be back to their original values
assertThat(
getAverageWritePoolUtilization(writeLoadPerShardSimulator, "node_0"),
equalTo(originalNode0WriteLoadStats.averageThreadPoolUtilization())
);
assertThat(
getAverageWritePoolUtilization(writeLoadPerShardSimulator, "node_1"),
equalTo(originalNode1WriteLoadStats.averageThreadPoolUtilization())
);
}

public void testMovementBetweenNodesWithNoThreadPoolStats() {
final var originalNode0WriteLoadStats = randomBoolean() ? randomUsageStats() : null;
final var originalNode1WriteLoadStats = randomBoolean() ? randomUsageStats() : null;
final var allocation = createRoutingAllocation(originalNode0WriteLoadStats, originalNode1WriteLoadStats);
final var writeLoadPerShardSimulator = new WriteLoadPerShardSimulator(allocation);

// Relocate a random shard from node_0 to node_1
final long expectedShardSize = randomNonNegativeLong();
final var randomShard = randomFrom(StreamSupport.stream(allocation.routingNodes().node("node_0").spliterator(), false).toList());
final var moveShardTuple = allocation.routingNodes().relocateShard(randomShard, "node_1", expectedShardSize, "testing", NOOP);
writeLoadPerShardSimulator.simulateShardStarted(moveShardTuple.v2());
allocation.routingNodes().startShard(moveShardTuple.v2(), NOOP, expectedShardSize);

final var generated = writeLoadPerShardSimulator.nodeUsageStatsForThreadPools();
assertThat(generated.containsKey("node_0"), equalTo(originalNode0WriteLoadStats != null));
assertThat(generated.containsKey("node_1"), equalTo(originalNode1WriteLoadStats != null));
}

private float getAverageWritePoolUtilization(WriteLoadPerShardSimulator writeLoadPerShardSimulator, String nodeId) {
final var generatedNodeUsageStates = writeLoadPerShardSimulator.nodeUsageStatsForThreadPools();
final var writePoolStats = generatedNodeUsageStates.get(nodeId).threadPoolUsageStatsMap().get("write");
return writePoolStats.averageThreadPoolUtilization();
}

private NodeUsageStatsForThreadPools.ThreadPoolUsageStats randomUsageStats() {
return new NodeUsageStatsForThreadPools.ThreadPoolUsageStats(
randomIntBetween(4, 16),
randomFloatBetween(0.1f, 1.0f, true),
randomLongBetween(0, 60_000)
);
}

private RoutingAllocation createRoutingAllocation(
NodeUsageStatsForThreadPools.ThreadPoolUsageStats node0WriteLoadStats,
NodeUsageStatsForThreadPools.ThreadPoolUsageStats node1WriteLoadStats
) {
final Map<String, NodeUsageStatsForThreadPools> nodeUsageStats = new HashMap<>();
if (node0WriteLoadStats != null) {
nodeUsageStats.put("node_0", new NodeUsageStatsForThreadPools("node_0", Map.of("write", node0WriteLoadStats)));
}
if (node1WriteLoadStats != null) {
nodeUsageStats.put("node_1", new NodeUsageStatsForThreadPools("node_1", Map.of("write", node1WriteLoadStats)));
}

return new RoutingAllocation(
new AllocationDeciders(List.of()),
createClusterState(),
ClusterInfo.builder().nodeUsageStatsForThreadPools(nodeUsageStats).build(),
SnapshotShardSizeInfo.EMPTY,
System.nanoTime()
).mutableCloneForSimulation();
}

private ClusterState createClusterState() {
return ClusterStateCreationUtils.stateWithAssignedPrimariesAndReplicas(new String[] { "indexOne", "indexTwo", "indexThree" }, 3, 0);
}
}