Skip to content

feat(ui): canvas stage, zoom, and fit improvements #8085

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 6, 2025
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { Spinner } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useAllEntityAdapters } from 'features/controlLayers/contexts/EntityAdapterContext';
import { computed } from 'nanostores';
import { memo, useMemo } from 'react';

export const CanvasBusySpinner = memo(() => {
const canvasManager = useCanvasManager();
const allEntityAdapters = useAllEntityAdapters();
const $isPendingRectCalculation = useMemo(
() =>
computed(
allEntityAdapters.map(({ transformer }) => transformer.$isPendingRectCalculation),
(...values) => values.some((v) => v)
),
[allEntityAdapters]
);
const isPendingRectCalculation = useStore($isPendingRectCalculation);
const isRasterizing = useStore(canvasManager.stateApi.$isRasterizing);
const isCompositing = useStore(canvasManager.compositor.$isBusy);

if (isRasterizing || isCompositing || isPendingRectCalculation) {
return <Spinner opacity={0.3} />;
}
return null;
});
CanvasBusySpinner.displayName = 'CanvasBusySpinner';
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
import { CanvasAlertsPreserveMask } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsPreserveMask';
import { CanvasAlertsSelectedEntityStatus } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSelectedEntityStatus';
import { CanvasAlertsSendingToGallery } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSendingTo';
import { CanvasBusySpinner } from 'features/controlLayers/components/CanvasBusySpinner';
import { CanvasContextMenuGlobalMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems';
import { CanvasContextMenuSelectedEntityMenuItems } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuSelectedEntityMenuItems';
import { CanvasDropArea } from 'features/controlLayers/components/CanvasDropArea';
Expand Down Expand Up @@ -106,6 +107,9 @@ export const CanvasMainPanelContent = memo(() => {
<MenuContent />
</Menu>
</Flex>
<Flex position="absolute" bottom={4} insetInlineEnd={4}>
<CanvasBusySpinner />
</Flex>
</CanvasManagerProviderGate>
</Flex>
)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,33 @@ export const useEntityAdapter = (
assert(adapter, 'useEntityAdapter must be used within a EntityAdapterContext');
return adapter;
};

export const useAllEntityAdapters = () => {
const canvasManager = useCanvasManager();
const regionalGuidanceAdapters = useSyncExternalStore(
canvasManager.adapters.regionMasks.subscribe,
canvasManager.adapters.regionMasks.getSnapshot
);
const rasterLayerAdapters = useSyncExternalStore(
canvasManager.adapters.rasterLayers.subscribe,
canvasManager.adapters.rasterLayers.getSnapshot
);
const controlLayerAdapters = useSyncExternalStore(
canvasManager.adapters.controlLayers.subscribe,
canvasManager.adapters.controlLayers.getSnapshot
);
const inpaintMaskAdapters = useSyncExternalStore(
canvasManager.adapters.inpaintMasks.subscribe,
canvasManager.adapters.inpaintMasks.getSnapshot
);
const allEntityAdapters = useMemo(() => {
return [
...Array.from(rasterLayerAdapters.values()),
...Array.from(controlLayerAdapters.values()),
...Array.from(inpaintMaskAdapters.values()),
...Array.from(regionalGuidanceAdapters.values()),
];
}, [controlLayerAdapters, inpaintMaskAdapters, rasterLayerAdapters, regionalGuidanceAdapters]);

return allEntityAdapters;
};
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ import {
selectCanvasSlice,
selectEntity,
} from 'features/controlLayers/store/selectors';
import {
type CanvasEntityIdentifier,
type CanvasRenderableEntityState,
isRasterLayerEntityIdentifier,
type Rect,
import type {
CanvasEntityIdentifier,
CanvasRenderableEntityState,
LifecycleCallback,
Rect,
} from 'features/controlLayers/store/types';
import { isRasterLayerEntityIdentifier } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast';
import Konva from 'konva';
import { atom } from 'nanostores';
Expand All @@ -40,11 +41,6 @@ import stableHash from 'stable-hash';
import { assert } from 'tsafe';
import type { Jsonifiable, JsonObject } from 'type-fest';

// Ideally, we'd type `adapter` as `CanvasEntityAdapterBase`, but the generics make this tricky. `CanvasEntityAdapter`
// is a union of all entity adapters and is functionally identical to `CanvasEntityAdapterBase`. We'll need to do a
// type assertion below in the `onInit` method, which calls these callbacks.
type InitCallback = (adapter: CanvasEntityAdapter) => Promise<boolean>;

export abstract class CanvasEntityAdapterBase<
T extends CanvasRenderableEntityState,
U extends string,
Expand Down Expand Up @@ -118,7 +114,7 @@ export abstract class CanvasEntityAdapterBase<
/**
* Callbacks that are executed when the module is initialized.
*/
private static initCallbacks = new Set<InitCallback>();
private static initCallbacks = new Set<LifecycleCallback>();

/**
* Register a callback to be run when an entity adapter is initialized.
Expand Down Expand Up @@ -165,7 +161,7 @@ export abstract class CanvasEntityAdapterBase<
* return false;
* });
*/
static registerInitCallback = (callback: InitCallback) => {
static registerInitCallback = (callback: LifecycleCallback) => {
const wrapped = async (adapter: CanvasEntityAdapter) => {
const result = await callback(adapter);
if (result) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
roundRect,
} from 'features/controlLayers/konva/util';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect, RectWithRotation } from 'features/controlLayers/store/types';
import type { Coordinate, LifecycleCallback, Rect, RectWithRotation } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast';
import Konva from 'konva';
import type { GroupConfig } from 'konva/lib/Group';
Expand Down Expand Up @@ -123,7 +123,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
/**
* Whether the transformer is currently calculating the rect of the parent.
*/
$isPendingRectCalculation = atom<boolean>(true);
$isPendingRectCalculation = atom<boolean>(false);

/**
* A set of subscriptions that should be cleaned up when the transformer is destroyed.
Expand Down Expand Up @@ -177,6 +177,11 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
*/
transformMutex = new Mutex();

/**
* Callbacks that are executed when the bbox is updated.
*/
private static bboxUpdatedCallbacks = new Set<LifecycleCallback>();

konva: {
transformer: Konva.Transformer;
proxyRect: Konva.Rect;
Expand Down Expand Up @@ -908,6 +913,8 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.parent.renderer.konva.objectGroup.setAttrs(groupAttrs);
this.parent.bufferRenderer.konva.group.setAttrs(groupAttrs);
}

CanvasEntityTransformer.runBboxUpdatedCallbacks(this.parent);
};

calculateRect = debounce(() => {
Expand Down Expand Up @@ -1026,6 +1033,23 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.konva.outlineRect.visible(false);
};

static registerBboxUpdatedCallback = (callback: LifecycleCallback) => {
const wrapped = async (adapter: CanvasEntityAdapter) => {
const result = await callback(adapter);
if (result) {
this.bboxUpdatedCallbacks.delete(wrapped);
}
return result;
};
this.bboxUpdatedCallbacks.add(wrapped);
};

private static runBboxUpdatedCallbacks = (adapter: CanvasEntityAdapter) => {
for (const callback of this.bboxUpdatedCallbacks) {
callback(adapter);
}
};

repr = () => {
return {
id: this.id,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { Property } from 'csstype';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { getKonvaNodeDebugAttrs, getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
import type { Coordinate, Dimensions, Rect, StageAttrs } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
Expand Down Expand Up @@ -186,6 +186,18 @@ export class CanvasStageModule extends CanvasModuleBase {
}
};

/**
* Fits the bbox and layers to the stage. The union of the bbox and the visible layers will be centered and scaled
* to fit the stage with some padding.
*/
fitBboxAndLayersToStage = (): void => {
const layersRect = this.manager.compositor.getVisibleRectOfType();
const bboxRect = this.manager.stateApi.getBbox().rect;
const unionRect = getRectUnion(layersRect, bboxRect);
this.log.trace({ bboxRect, layersRect, unionRect }, 'Fitting bbox and layers to stage');
this.fitRect(unionRect);
};

/**
* Fits a rectangle to the stage. The rectangle will be centered and scaled to fit the stage with some padding.
*
Expand Down Expand Up @@ -218,14 +230,23 @@ export class CanvasStageModule extends CanvasModuleBase {
this._intendedScale = scale;
this._activeSnapPoint = null;

this.konva.stage.setAttrs({
const tween = new Konva.Tween({
node: this.konva.stage,
duration: 0.15,
x,
y,
scaleX: scale,
scaleY: scale,
easing: Konva.Easings.EaseInOut,
onUpdate: () => {
this.syncStageAttrs();
},
onFinish: () => {
this.syncStageAttrs();
tween.destroy();
},
});

this.syncStageAttrs({ x, y, scale });
tween.play();
};

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
import { zMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
Expand Down Expand Up @@ -611,3 +612,7 @@ export const isMaskEntityIdentifier = (
): entityIdentifier is CanvasEntityIdentifier<'inpaint_mask' | 'regional_guidance'> => {
return isInpaintMaskEntityIdentifier(entityIdentifier) || isRegionalGuidanceEntityIdentifier(entityIdentifier);
};

// Ideally, we'd type `adapter` as `CanvasEntityAdapterBase`, but the generics make this tricky. `CanvasEntityAdapter`
// is a union of all entity adapters and is functionally identical to `CanvasEntityAdapterBase`.
export type LifecycleCallback = (adapter: CanvasEntityAdapter) => Promise<boolean>;
20 changes: 15 additions & 5 deletions invokeai/frontend/web/src/features/imageActions/actions.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { AppDispatch, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultIPAdapter, selectDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasReset } from 'features/controlLayers/store/actions';
import {
Expand Down Expand Up @@ -173,15 +174,24 @@ export const newCanvasFromImage = async (arg: {
imageObject = imageDTOToImageObject(imageDTO);
}

const { x, y } = selectBboxRect(state);
const addFitOnLayerInitCallback = (adapterId: string) => {
CanvasEntityTransformer.registerBboxUpdatedCallback((adapter) => {
// Skip the callback if the adapter is not the one we are creating
if (adapter.id !== adapterId) {
return Promise.resolve(false);
}
adapter.manager.stage.fitBboxAndLayersToStage();
return Promise.resolve(true);
});
};

switch (type) {
case 'raster_layer': {
const overrides = {
id: getPrefixedId('raster_layer'),
objects: [imageObject],
position: { x, y },
} satisfies Partial<CanvasRasterLayerState>;
addFitOnLayerInitCallback(overrides.id);
dispatch(canvasReset());
// The `bboxChangedFromCanvas` reducer does no validation! Careful!
dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height }));
Expand All @@ -192,9 +202,9 @@ export const newCanvasFromImage = async (arg: {
const overrides = {
id: getPrefixedId('control_layer'),
objects: [imageObject],
position: { x, y },
controlAdapter: deepClone(initialControlNet),
} satisfies Partial<CanvasControlLayerState>;
addFitOnLayerInitCallback(overrides.id);
dispatch(canvasReset());
// The `bboxChangedFromCanvas` reducer does no validation! Careful!
dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height }));
Expand All @@ -205,8 +215,8 @@ export const newCanvasFromImage = async (arg: {
const overrides = {
id: getPrefixedId('inpaint_mask'),
objects: [imageObject],
position: { x, y },
} satisfies Partial<CanvasInpaintMaskState>;
addFitOnLayerInitCallback(overrides.id);
dispatch(canvasReset());
// The `bboxChangedFromCanvas` reducer does no validation! Careful!
dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height }));
Expand All @@ -217,8 +227,8 @@ export const newCanvasFromImage = async (arg: {
const overrides = {
id: getPrefixedId('regional_guidance'),
objects: [imageObject],
position: { x, y },
} satisfies Partial<CanvasRegionalGuidanceState>;
addFitOnLayerInitCallback(overrides.id);
dispatch(canvasReset());
// The `bboxChangedFromCanvas` reducer does no validation! Careful!
dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height }));
Expand Down