diff --git a/src/hooks/useRive.tsx b/src/hooks/useRive.tsx index 6abd179..3f0642e 100644 --- a/src/hooks/useRive.tsx +++ b/src/hooks/useRive.tsx @@ -25,6 +25,8 @@ function RiveComponent({ setCanvasRef, className = '', style, + width, + height, ...rest }: RiveComponentProps & ComponentProps<'canvas'>) { const containerStyle = { @@ -39,7 +41,13 @@ function RiveComponent({ className={className} {...(!className && { style: containerStyle })} > - + ); } @@ -130,18 +138,28 @@ export default function useRive( const boundsChanged = width !== dimensions.width || height !== dimensions.height; if (canvasRef.current && rive && boundsChanged) { + const widthProp = canvasRef.current.getAttribute('data-rive-width-prop'); + const heightProp = canvasRef.current.getAttribute('data-rive-height-prop'); if (options.fitCanvasToArtboardHeight) { containerRef.current.style.height = height + 'px'; } if (options.useDevicePixelRatio) { const dpr = window.devicePixelRatio || 1; - canvasRef.current.width = dpr * width; - canvasRef.current.height = dpr * height; + if (!widthProp) { + canvasRef.current.width = dpr * width; + } + if (!heightProp) { + canvasRef.current.height = dpr * height; + } canvasRef.current.style.width = width + 'px'; canvasRef.current.style.height = height + 'px'; } else { - canvasRef.current.width = width; - canvasRef.current.height = height; + if (!widthProp) { + canvasRef.current.width = width; + } + if (!heightProp) { + canvasRef.current.height = height; + } } setDimensions({ width, height }); diff --git a/test/useRive.test.tsx b/test/useRive.test.tsx index 94194a7..2d80a65 100644 --- a/test/useRive.test.tsx +++ b/test/useRive.test.tsx @@ -376,4 +376,35 @@ describe('useRive', () => { ); expect(container.firstChild).not.toHaveStyle('width: 50%'); }); + + it('container bounds do not override user-provided canvas resolutions', async () => { + const params = { + src: 'file-src', + }; + + global.devicePixelRatio = 2; + + const riveMock = { + ...baseRiveMock, + resizeToCanvas: jest.fn(), + }; + // @ts-ignore + mocked(rive.Rive).mockImplementation(() => riveMock); + + const containerSpy = document.createElement('div'); + jest.spyOn(containerSpy, 'clientWidth', 'get').mockReturnValue(600); + jest.spyOn(containerSpy, 'clientHeight', 'get').mockReturnValue(100); + const { result } = renderHook(() => useRive(params)); + + const { RiveComponent: RiveTestComponent } = result.current; + render( + + ); + await act(async () => { + result.current.setContainerRef(containerSpy); + controlledRiveloadCb(); + }); + + expect(result.current.canvas).toHaveAttribute('width', '300'); + }); });