diff --git a/.bazelignore b/.bazelignore index 1721481a919a..8a35bff1acf1 100644 --- a/.bazelignore +++ b/.bazelignore @@ -8,6 +8,7 @@ app/lezer-markdown/node_modules app/project-manager-shim/node_modules app/table-expression/node_modules app/rust-ffi/node_modules +app/ydoc-channel/node_modules app/ydoc-server/node_modules app/ydoc-server-nodejs/node_modules app/ydoc-server-polyglot/node_modules diff --git a/BUILD.bazel b/BUILD.bazel index c32a7af1f89f..56481c2e3adf 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -141,6 +141,7 @@ ENGINE_DIST_SOURCES = STDLIB_SOURCES + SBT_PROJECT_FILES + [ "lib/java/poi-wrapper/src/**", "lib/java/runtime-utils/src/**", "lib/java/scala-libs-wrapper/src/**", + "lib/java/ydoc-api/src/**", "lib/java/ydoc-polyfill/src/**", "lib/java/ydoc-server/src/**", "lib/java/ydoc-server-registration/src/**", diff --git a/app/gui/package.json b/app/gui/package.json index 93f887733143..a7725a2be71c 100644 --- a/app/gui/package.json +++ b/app/gui/package.json @@ -124,6 +124,7 @@ "y-protocols": "^1.0.6", "y-textarea": "^1.0.2", "y-websocket": "^1.5.4", + "ydoc-channel": "workspace:*", "ydoc-shared": "workspace:*", "yjs": "^13.6.21", "zod": "catalog:", diff --git a/app/gui/src/project-view/util/crdt.ts b/app/gui/src/project-view/util/crdt.ts index a27800bea3fa..c646be13555a 100644 --- a/app/gui/src/project-view/util/crdt.ts +++ b/app/gui/src/project-view/util/crdt.ts @@ -47,6 +47,8 @@ interface SubdocsEvent { export type ProviderParams = { /** URL for the project's language server RPC connection. */ ls: string + /** URL for the project's data connection. */ + data: string } /** TODO: Add docs */ diff --git a/app/gui/src/project-view/util/net.ts b/app/gui/src/project-view/util/net.ts index 6baca26d5215..198dfd72ff35 100644 --- a/app/gui/src/project-view/util/net.ts +++ b/app/gui/src/project-view/util/net.ts @@ -1,27 +1,19 @@ import { onScopeDispose } from 'vue' +import { YjsChannel } from 'ydoc-channel' import { AbortScope } from 'ydoc-shared/util/net' -import { - ReconnectingWebSocket, - ReconnectingWebSocketTransport, -} from 'ydoc-shared/util/net/ReconnectingWSTransport' +import { YjsTransport } from 'ydoc-shared/util/net/YjsTransport' +import * as Y from 'yjs' export { AbortScope } -const WS_OPTIONS = { - // We do not want to enqueue any messages, because after reconnecting we have to initProtocol again. - maxEnqueuedMessages: 0, -} - /** TODO: Add docs */ -export function createRpcTransport(url: string): ReconnectingWebSocketTransport { - return new ReconnectingWebSocketTransport(url, WS_OPTIONS) +export function createRpcTransport(indexDoc: Y.Doc, url: string): YjsTransport { + return new YjsTransport(indexDoc, url) } /** TODO: Add docs */ -export function createDataWebsocket(url: string, binaryType: 'arraybuffer' | 'blob'): WebSocket { - const websocket = new ReconnectingWebSocket(url, undefined, WS_OPTIONS) - websocket.binaryType = binaryType - return websocket as WebSocket +export function createDataSocket(indexDoc: Y.Doc, url: string): YjsChannel { + return new YjsChannel(indexDoc, url) } export interface WebSocketHandler { diff --git a/app/gui/src/project-view/util/net/dataServer.ts b/app/gui/src/project-view/util/net/dataServer.ts index 9c3a9d94946c..e79c3f610bed 100644 --- a/app/gui/src/project-view/util/net/dataServer.ts +++ b/app/gui/src/project-view/util/net/dataServer.ts @@ -1,5 +1,6 @@ import { Err, Ok, type Result } from 'enso-common/src/utilities/data/result' import { ObservableV2 } from 'lib0/observable' +import type { YjsChannel } from 'ydoc-channel' import { Builder, ByteBuffer, @@ -60,20 +61,22 @@ export class DataServer extends ObservableV2 { /** `websocket.binaryType` should be `ArrayBuffer`. */ constructor( public clientId: string, - public websocket: WebSocket, + public websocket: YjsChannel, abort: AbortScope, ) { super() abort.handleDispose(this) websocket.addEventListener('message', ({ data: rawPayload }) => { - if (!(rawPayload instanceof ArrayBuffer)) { + if (!ArrayBuffer.isView(rawPayload)) { console.warn('Data Server: Data type was invalid:', rawPayload) // Ignore all non-binary messages. If the messages are `Blob`s instead, this is a // misconfiguration and should also be ignored. return } - const binaryMessage = OutboundMessage.getRootAsOutboundMessage(new ByteBuffer(rawPayload)) + const binaryMessage = OutboundMessage.getRootAsOutboundMessage( + new ByteBuffer(rawPayload.buffer), + ) const payloadType = binaryMessage.payloadType() const payload = binaryMessage.payload(new PAYLOAD_CONSTRUCTOR[payloadType]()) if (!payload) return @@ -101,8 +104,7 @@ export class DataServer extends ObservableV2 { this.scheduleInitializationAfterConnect() }) - if (websocket.readyState === WebSocket.OPEN) this.initialized = this.initialize() - else this.initialized = this.scheduleInitializationAfterConnect() + this.initialized = this.initialize() } /** TODO: Add docs */ diff --git a/app/gui/src/providers/openedProjects/project/project.ts b/app/gui/src/providers/openedProjects/project/project.ts index 12a4cb4a0c01..93e336b032cc 100644 --- a/app/gui/src/providers/openedProjects/project/project.ts +++ b/app/gui/src/providers/openedProjects/project/project.ts @@ -14,7 +14,7 @@ import { nextEvent } from '@/util/data/observable' import type { Opt } from '@/util/data/opt' import { ReactiveMapping } from '@/util/database/reactiveDb' import type { MethodPointer } from '@/util/methodPointer' -import { createDataWebsocket, createRpcTransport, useAbortScope } from '@/util/net' +import { createDataSocket, createRpcTransport, useAbortScope } from '@/util/net' import { DataServer } from '@/util/net/dataServer' import { ProjectPath } from '@/util/projectPath' import { tryQualifiedName, type QualifiedName } from '@/util/qualifiedName' @@ -74,15 +74,17 @@ export function createProjectStore( const doc = new Y.Doc() const awareness = new Awareness(doc) - + const ydocUrl = resolveYDocUrl(props.engine.rpcUrl, props.engine.ydocUrl) + const guiRpcId = `gui-rpc-${crypto.randomUUID()}` const clientId = crypto.randomUUID() as Uuid - const lsRpcConnection = createLsRpcConnection(clientId, props.engine.rpcUrl, abort) + const lsRpcConnection = createLsRpcConnection(clientId, doc, guiRpcId, abort) const projectRootId = lsRpcConnection.contentRoots.then( (roots) => roots.find((root) => root.type === 'Project')?.id, ) onScopeDispose(() => lsRpcConnection.release()) - const dataConnection = initializeDataConnection(clientId, props.engine.dataUrl, abort) + const guiDataId = `gui-data-${crypto.randomUUID()}` + const dataConnection = initializeDataConnection(clientId, doc, guiDataId, abort) const rpcUrl = new URL(props.engine.rpcUrl) const isOnLocalBackend = rpcUrl.protocol === 'mock:' || @@ -91,23 +93,12 @@ export function createProjectStore( rpcUrl.hostname === '[::1]' || rpcUrl.hostname === '0:0:0:0:0:0:0:1' - const moduleProjectPath = computed((): Result | undefined => { - const filePath = observedFileName.value - if (filePath == null) return undefined - const withoutFileExt = filePath.replace(/\.enso$/, '') - const withDotSeparators = withoutFileExt.replace(/\//g, '.') - const qn = tryQualifiedName(withDotSeparators) - if (!qn.ok) return qn - return Ok(ProjectPath.create(undefined, qn.value)) - }) - - const ydocUrl = resolveYDocUrl(props.engine.rpcUrl, props.engine.ydocUrl) let yDocsProvider: ReturnType | undefined watchEffect((onCleanup) => { yDocsProvider = attachProvider( ydocUrl.href, 'index', - { ls: props.engine.rpcUrl }, + { ls: guiRpcId, data: guiDataId }, doc, awareness.internal, ) @@ -117,6 +108,16 @@ export function createProjectStore( }) }) + const moduleProjectPath = computed((): Result | undefined => { + const filePath = observedFileName.value + if (filePath == null) return undefined + const withoutFileExt = filePath.replace(/\.enso$/, '') + const withDotSeparators = withoutFileExt.replace(/\//g, '.') + const qn = tryQualifiedName(withDotSeparators) + if (!qn.ok) return qn + return Ok(ProjectPath.create(undefined, qn.value)) + }) + const projectModel = new DistributedProject(doc) const entryPoint = computed(() => { @@ -443,8 +444,13 @@ function resolveYDocUrl(rpcUrl: string, url: string): URL { return resolved } -function createLsRpcConnection(clientId: Uuid, url: string, abort: AbortScope): LanguageServer { - const transport = createRpcTransport(url) +function createLsRpcConnection( + clientId: Uuid, + doc: Y.Doc, + url: string, + abort: AbortScope, +): LanguageServer { + const transport = createRpcTransport(doc, url) const connection = new LanguageServer(clientId, transport) abort.onAbort(() => { connection.stopReconnecting() @@ -453,8 +459,8 @@ function createLsRpcConnection(clientId: Uuid, url: string, abort: AbortScope): return connection } -function initializeDataConnection(clientId: Uuid, url: string, abort: AbortScope) { - const client = createDataWebsocket(url, 'arraybuffer') +function initializeDataConnection(clientId: Uuid, doc: Y.Doc, url: string, abort: AbortScope) { + const client = createDataSocket(doc, url) const connection = new DataServer(clientId, client, abort) onScopeDispose(() => connection.dispose()) return connection diff --git a/app/project-manager-shim/src/projectService/ensoRunner.ts b/app/project-manager-shim/src/projectService/ensoRunner.ts index ba0246223265..c3336bfb83e3 100644 --- a/app/project-manager-shim/src/projectService/ensoRunner.ts +++ b/app/project-manager-shim/src/projectService/ensoRunner.ts @@ -87,8 +87,14 @@ export class EnsoRunner implements Runner { args: readonly string[], spawnCallback: (cmd: string, cmdArgs: readonly string[]) => T, ) { - const cmd = this.ensoPath.endsWith('.bat') ? 'cmd.exe' : this.ensoPath - const cmdArgs = this.ensoPath.endsWith('.bat') ? ['/c', this.ensoPath, ...args] : args + const [cmd, cmdArgs] = + this.ensoPath.endsWith('.bat') ? + ['cmd.exe', ['/c', this.ensoPath, ...args]] + : [this.ensoPath, args] + const isDevMode = process.env.NODE_ENV === 'development' + if (isDevMode) { + console.log('runProcess', cmd, 'with', cmdArgs) + } return spawnCallback(cmd, cmdArgs) } diff --git a/app/ydoc-channel/BUILD.bazel b/app/ydoc-channel/BUILD.bazel new file mode 100644 index 000000000000..291bc5795187 --- /dev/null +++ b/app/ydoc-channel/BUILD.bazel @@ -0,0 +1,34 @@ +load("@aspect_rules_js//npm:defs.bzl", "npm_package") +load("@aspect_rules_ts//ts:defs.bzl", "ts_config", "ts_project") +load("@npm//:defs.bzl", "npm_link_all_packages", "npm_link_targets") + +npm_link_all_packages(name = "node_modules") + +ts_config( + name = "tsconfig", + src = "tsconfig.json", + deps = ["//:tsconfig"], +) + +ts_project( + name = "tsc", + srcs = glob(["src/**/*.ts"]), + composite = True, + out_dir = "dist", + root_dir = "src", + tsconfig = ":tsconfig", + validate = select({ + "@platforms//os:windows": False, + "//conditions:default": True, + }), + deps = npm_link_targets(), +) + +npm_package( + name = "pkg", + srcs = [ + "package.json", + ":tsc", + ], + visibility = ["//visibility:public"], +) diff --git a/app/ydoc-channel/package.json b/app/ydoc-channel/package.json new file mode 100644 index 000000000000..33b7b17f28a0 --- /dev/null +++ b/app/ydoc-channel/package.json @@ -0,0 +1,28 @@ +{ + "name": "ydoc-channel", + "version": "0.1.0", + "description": "Y.js-based bidirectional communication channel", + "type": "module", + "exports": { + ".": { + "source": "./src/index.ts", + "types": "./dist/index.d.ts", + "import": "./dist/index.js" + } + }, + "main": "src/index.ts", + "scripts": { + "test:unit": "vitest run", + "compile": "tsc", + "lint": "eslint . --cache --max-warnings=0" + }, + "devDependencies": { + "@types/node": "catalog:", + "typescript": "catalog:", + "vitest": "catalog:" + }, + "dependencies": { + "lib0": "^0.2.99", + "yjs": "^13.6.19" + } +} diff --git a/app/ydoc-channel/src/YjsChannel.test.ts b/app/ydoc-channel/src/YjsChannel.test.ts new file mode 100644 index 000000000000..f9958eb67238 --- /dev/null +++ b/app/ydoc-channel/src/YjsChannel.test.ts @@ -0,0 +1,352 @@ +import { describe, expect, it } from 'vitest' +import * as Y from 'yjs' +import { YjsChannel } from './YjsChannel.js' + +// Mock CloseEvent for Node.js environment +if (typeof globalThis.CloseEvent === 'undefined') { + class CloseEvent extends Event { + constructor(type: string) { + super(type) + } + } + ;(globalThis as any).CloseEvent = CloseEvent +} + +describe('YjsChannel', () => { + it('should send and receive messages between two channels', () => { + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + const receivedMessages: string[] = [] + + // Subscribe channel2 to receive messages + channel2.subscribe((message) => { + receivedMessages.push(message) + }) + + // Send message from channel1 + channel1.send('Hello from channel1') + + // Channel2 should receive the message + expect(receivedMessages).toEqual(['Hello from channel1']) + }) + + it('should not receive its own messages', () => { + const doc = new Y.Doc() + const channel = new YjsChannel(doc, 'test-channel') + + const receivedMessages: string[] = [] + + // Subscribe to own channel + channel.subscribe((message) => { + receivedMessages.push(message) + }) + + // Send message from the same channel + channel.send('Hello from myself') + + // Should not receive own message + expect(receivedMessages).toEqual([]) + }) + + it('should allow multiple subscribers', () => { + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + const received1: string[] = [] + const received2: string[] = [] + + // Multiple subscribers on channel2 + channel2.subscribe((message) => received1.push(message)) + channel2.subscribe((message) => received2.push(message)) + + // Send message from channel1 + channel1.send('Broadcast message') + + // Both subscribers should receive the message + expect(received1).toEqual(['Broadcast message']) + expect(received2).toEqual(['Broadcast message']) + + // Send message from channel1 + channel1.send('Broadcast message 1') + + // Both subscribers should receive the message + expect(received1).toEqual(['Broadcast message', 'Broadcast message 1']) + expect(received2).toEqual(['Broadcast message', 'Broadcast message 1']) + }) + + it('should support unsubscribing', () => { + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + const receivedMessages: string[] = [] + + // Subscribe and then unsubscribe + const unsubscribe = channel2.subscribe((message) => { + receivedMessages.push(message) + }) + + channel1.send('First message') + unsubscribe() + channel1.send('Second message') + + // Should only receive the first message + expect(receivedMessages).toEqual(['First message']) + }) + + it('should handle complex message types', () => { + interface ComplexMessage { + id: number + data: string + nested: { value: boolean } + } + + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + let receivedMessage: ComplexMessage | undefined + + channel2.subscribe((message) => { + receivedMessage = message + }) + + const testMessage: ComplexMessage = { + id: 42, + data: 'test data', + nested: { value: true }, + } + + channel1.send(testMessage) + + expect(receivedMessage).toEqual(testMessage) + }) + + it('should send and receive ArrayBuffer messages', () => { + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + let receivedBuffer: ArrayBuffer | undefined + + channel2.subscribe((message) => { + receivedBuffer = message + }) + + // Create a test ArrayBuffer with some data + const view = new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8]) + const testBuffer = view.buffer + + channel1.send(testBuffer) + + expect(receivedBuffer).toBeDefined() + expect(receivedBuffer?.byteLength).toBe(8) + + // Verify the contents + const receivedView = new Uint8Array(receivedBuffer!) + expect(Array.from(receivedView)).toEqual([1, 2, 3, 4, 5, 6, 7, 8]) + }) + + it('should send and receive Uint8Array messages', () => { + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + let receivedArray: Uint8Array | undefined + channel2.subscribe((message) => { + receivedArray = message + }) + + // Create a test Uint8Array with some data + const view = new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8]) + + channel1.send(view) + + expect(receivedArray).toBeDefined() + expect(receivedArray?.byteLength).toBe(8) + + // Verify the contents + expect(Array.from(receivedArray!)).toEqual([1, 2, 3, 4, 5, 6, 7, 8]) + }) + + it('should clean up properly when closed', () => { + const doc = new Y.Doc() + const channel = new YjsChannel(doc, 'test-channel') + + const receivedMessages: string[] = [] + channel.subscribe((message) => { + receivedMessages.push(message) + }) + + channel.close() + + // After dispose, the channel should no longer receive messages + const channel2 = new YjsChannel(doc, 'test-channel') + channel2.send('Message after dispose') + + expect(receivedMessages).toEqual([]) + }) + + it('should cleanup internal storage after receiving', () => { + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + const receivedMessages: string[] = [] + + // Subscribe channel2 to receive messages + channel2.subscribe((message) => { + receivedMessages.push(message) + }) + + // Send message from channel1 + channel1.send('Hello from channel1') + + // Channel2 should receive the message + expect(receivedMessages).toEqual(['Hello from channel1']) + + expect(doc.getArray('test-channel').length).toEqual(0) + }) + + describe('WebSocket-compatible API', () => { + it('should support addEventListener for message events', () => { + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + const receivedMessages: string[] = [] + + // Use addEventListener to listen for messages + channel2.addEventListener('message', (event) => { + receivedMessages.push(event.data) + }) + + // Send message from channel1 + channel1.send('Hello via addEventListener') + + // Channel2 should receive the message with MessageEvent structure + expect(receivedMessages).toEqual(['Hello via addEventListener']) + }) + + it('should support removeEventListener', () => { + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + const receivedMessages: string[] = [] + + const listener = (event: MessageEvent) => { + receivedMessages.push(event.data) + } + + // Add and then remove event listener + channel2.addEventListener('message', listener) + channel1.send('First message') + + channel2.removeEventListener('message', listener) + channel1.send('Second message') + + // Should only receive the first message + expect(receivedMessages).toEqual(['First message']) + }) + + it('should support on/off methods', () => { + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + const receivedMessages: string[] = [] + + const listener = (event: MessageEvent) => { + receivedMessages.push(event.data) + } + + // Use on to add listener + channel2.on('message', listener) + channel1.send('First message') + + // Use off to remove listener + channel2.off('message', listener) + channel1.send('Second message') + + // Should only receive the first message + expect(receivedMessages).toEqual(['First message']) + }) + + it('should support addEventListener with once option', () => { + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + const receivedMessages: string[] = [] + + // Add listener with once option + channel2.addEventListener( + 'message', + (event) => { + receivedMessages.push(event.data) + }, + { once: true }, + ) + + // Send multiple messages + channel1.send('First message') + channel1.send('Second message') + channel1.send('Third message') + + // Should only receive the first message due to once option + expect(receivedMessages).toEqual(['First message']) + }) + + it('should emit close event when closed', () => { + const doc = new Y.Doc() + const channel = new YjsChannel(doc, 'test-channel') + + let closeEventFired = false + + channel.addEventListener('close', () => { + closeEventFired = true + }) + + channel.close() + + expect(closeEventFired).toBe(true) + }) + + it('should support emitOpen event', () => { + const doc = new Y.Doc() + const channel = new YjsChannel(doc, 'test-channel') + + let openEventFired = false + + channel.addEventListener('open', () => { + openEventFired = true + }) + + expect(openEventFired).toBe(true) + }) + + it('should handle multiple addEventListener calls for the same event', () => { + const doc = new Y.Doc() + const channel1 = new YjsChannel(doc, 'test-channel') + const channel2 = new YjsChannel(doc, 'test-channel') + + const received1: string[] = [] + const received2: string[] = [] + + // Add multiple listeners + channel2.addEventListener('message', (event) => received1.push(event.data)) + channel2.addEventListener('message', (event) => received2.push(event.data)) + + // Send message + channel1.send('Broadcast message') + + // Both listeners should receive the message + expect(received1).toEqual(['Broadcast message']) + expect(received2).toEqual(['Broadcast message']) + }) + }) +}) diff --git a/app/ydoc-channel/src/YjsChannel.ts b/app/ydoc-channel/src/YjsChannel.ts new file mode 100644 index 000000000000..fbfce43aa1ea --- /dev/null +++ b/app/ydoc-channel/src/YjsChannel.ts @@ -0,0 +1,213 @@ +import { ObservableV2 } from 'lib0/observable' +import * as Y from 'yjs' + +interface AddEventListenerOptions { + capture?: boolean + once?: boolean + passive?: boolean + signal?: AbortSignal +} + +/** + * Message handler callback type. + */ +export type MessageHandler = (message: T) => void + +/** + * Callbacks for YjsChannel lifecycle events. + */ +export interface YjsChannelCallbacks { + /** + * Called when the message channel is connected and ready to use. + * @param channel - The connected YjsChannel instance + */ + onConnect(channel: YjsChannel): void +} + +/** + * ObservableV2-compatible event handlers for WebSocketEventMap. + */ +type WebSocketEventHandlers = { + [K in keyof WebSocketEventMap]: (event: WebSocketEventMap[K]) => void +} + +/** + * A bidirectional communication channel backed by Y.Array. + * + * This class allows multiple parties to send and receive messages through a shared + * Y.Array CRDT. Implements WebSocket-like event API for compatibility. + */ +export class YjsChannel extends ObservableV2 { + private readonly senderId: string + private readonly doc: Y.Doc + private readonly array: Y.Array + private readonly handlers: Set> = new Set() + private readonly observeHandler: (event: Y.YArrayEvent, tr: Y.Transaction) => void + + /** + * Creates a new YjsChannel. + * @param doc - The shared Y.Doc document + * @param channelName - The name of the channel (used to get/create the Y.Array) + */ + constructor(doc: Y.Doc, channelName: string) { + super() + this.senderId = crypto.randomUUID() + this.doc = doc + this.array = doc.getArray(channelName) + + this.observeHandler = (event: Y.YArrayEvent, transaction: Y.Transaction) => { + // Only notify handlers if the message is from another sender + if (transaction.origin !== this.senderId) { + doc.transact(() => { + // Process all added items + for (const delta of event.changes.delta) { + if (delta.insert) { + const items = Array.isArray(delta.insert) ? delta.insert : [delta.insert] + for (const item of items) { + this.notifyHandlers(item) + this.array.delete(0) + } + } + } + }, this.senderId) + } + } + + this.array.observe(this.observeHandler) + } + + /** + * Sends a message to the channel. + * @param message - The message to send + */ + send(message: T): void { + this.doc.transact(() => this.array.push([message]), this.senderId) + } + + /** + * Subscribes to messages received from other parties. + * @param handler - The callback to invoke when a message is received + * @returns A function to unsubscribe the handler + */ + subscribe(handler: MessageHandler): () => void { + this.handlers.add(handler) + return () => { + this.handlers.delete(handler) + } + } + + /** + * Removes all message handlers and stops observing the Y.Array. + */ + close(): void { + this.array.unobserve(this.observeHandler) + this.handlers.clear() + this.emitClose() + } + + /** + * Add an event listener to the channel (alias for addEventListener). + */ + override on( + type: K, + cb: (event: WebSocketEventMap[K]) => void, + options?: AddEventListenerOptions, + ): any { + // If subscribing to 'open' event, call the callback immediately + // since the channel is always open after creation + if (type === 'open') { + try { + cb(new Event('open') as WebSocketEventMap[K]) + } catch (e) { + const error = new Error(`YjsChannel error handling open event ${e}`) + ;(error as any).target = e + this.emitError(error) + } + // Don't add to listeners if 'once' option is set + if (options?.once) { + return cb + } + } + + if (options?.once) { + return super.once(type, cb as any) + } else { + return super.on(type, cb as any) + } + } + + /** + * Remove an event listener from the channel (alias for removeEventListener). + */ + override off( + type: K, + cb: (event: WebSocketEventMap[K]) => void, + _options?: AddEventListenerOptions, + ): void { + super.off(type, cb as any) + } + + /** + * WebSocket-compatible addEventListener method. + * Add an event listener to the channel. + */ + addEventListener( + type: K, + cb: (event: WebSocketEventMap[K]) => void, + options?: AddEventListenerOptions, + ): void { + this.on(type, cb, options) + } + + /** + * WebSocket-compatible removeEventListener method. + * Remove an event listener from the channel. + */ + removeEventListener( + type: K, + cb: (event: WebSocketEventMap[K]) => void, + options?: AddEventListenerOptions, + ): void { + this.off(type, cb, options) + } + + /** + * Notifies all subscribed handlers with the received message. + */ + protected notifyHandlers(message: any): void { + // Create a MessageEvent-like object for WebSocket compatibility + const messageEvent = { data: message } as MessageEvent + + // Emit event for addEventListener listeners + super.emit('message', [messageEvent]) + + // Call legacy subscribe handlers for backward compatibility + for (const handler of this.handlers) { + try { + handler(message) + } catch (e) { + const error = new Error(`Failed to handle message: ${message}`) + ;(error as any).target = e + this.emitError(error) + } + } + } + + /** + * Emit a 'close' event to signal the channel is closed. + */ + private emitClose(): void { + super.emit('close', [new CloseEvent('close')]) + } + + /** + * Emit an 'error' event to signal an error occurred. + */ + private emitError(error?: Error): void { + const errorEvent = new Event('error') + if (error) { + ;(errorEvent as any).error = error + } + super.emit('error', [errorEvent]) + } +} diff --git a/app/ydoc-channel/src/index.ts b/app/ydoc-channel/src/index.ts new file mode 100644 index 000000000000..a7e64201579b --- /dev/null +++ b/app/ydoc-channel/src/index.ts @@ -0,0 +1 @@ +export { YjsChannel, type MessageHandler, type YjsChannelCallbacks } from './YjsChannel.js' diff --git a/app/ydoc-channel/tsconfig.json b/app/ydoc-channel/tsconfig.json new file mode 100644 index 000000000000..aaa83845bc60 --- /dev/null +++ b/app/ydoc-channel/tsconfig.json @@ -0,0 +1,11 @@ +{ + "extends": "../../tsconfig.json", + "compilerOptions": { + "outDir": "dist", + "rootDir": "src", + "noEmit": false, + "isolatedModules": true, + "composite": true + }, + "include": ["src/**/*"] +} diff --git a/app/ydoc-server-polyglot/build.mjs b/app/ydoc-server-polyglot/build.mjs index 80f90fd7d334..635fa674b8f5 100644 --- a/app/ydoc-server-polyglot/build.mjs +++ b/app/ydoc-server-polyglot/build.mjs @@ -10,6 +10,10 @@ const globals = { varName: 'zlib', type: 'cjs', }, + 'node:crypto': { + varName: 'crypto', + type: 'cjs', + }, } const ctx = await esbuild.context({ diff --git a/app/ydoc-server-polyglot/package.json b/app/ydoc-server-polyglot/package.json index a3525fec4afc..f2a1e4fcc458 100644 --- a/app/ydoc-server-polyglot/package.json +++ b/app/ydoc-server-polyglot/package.json @@ -15,8 +15,8 @@ "lint": "eslint . --cache --max-warnings=0" }, "dependencies": { - "ydoc-server": "workspace:*", - "ydoc-shared": "workspace:*" + "ydoc-channel": "workspace:*", + "ydoc-server": "workspace:*" }, "devDependencies": { "@fal-works/esbuild-plugin-global-externals": "^2.1.2", diff --git a/app/ydoc-server-polyglot/src/main.ts b/app/ydoc-server-polyglot/src/main.ts index d3d0df871e97..3dc6d75c05df 100644 --- a/app/ydoc-server-polyglot/src/main.ts +++ b/app/ydoc-server-polyglot/src/main.ts @@ -6,14 +6,31 @@ const debug = typeof YDOC_LS_DEBUG != 'undefined' configureAllDebugLogs(debug) +if (YDOC_JSON_CHANNEL_CALLBACKS == undefined) { + throw new Error('YDOC_JSON_CHANNEL_CALLBACKS undefined') +} +if (YDOC_BINARY_CHANNEL_CALLBACKS == undefined) { + throw new Error('YDOC_BINARY_CHANNEL_CALLBACKS undefined') +} + +const ByteBuffer = Java.type('java.nio.ByteBuffer') + const wss = new WebSocketServer({ host, port }) wss.onconnect = (socket, url) => { const doc = docName(url.pathname) const ls = url.searchParams.get('ls') + const data = url.searchParams.get('data') if (doc != null && ls != null) { - console.log('setupGatewayClient', ls, doc) - setupGatewayClient(socket, ls, doc) + setupGatewayClient( + socket, + ls, + data, + doc, + ByteBuffer, + YDOC_JSON_CHANNEL_CALLBACKS, + YDOC_BINARY_CHANNEL_CALLBACKS, + ) } else { console.log('Failed to authenticate user', ls, doc) } diff --git a/app/ydoc-server-polyglot/src/polyglot.d.ts b/app/ydoc-server-polyglot/src/polyglot.d.ts index 83c063a41b2b..323ec43617d6 100644 --- a/app/ydoc-server-polyglot/src/polyglot.d.ts +++ b/app/ydoc-server-polyglot/src/polyglot.d.ts @@ -1,14 +1,22 @@ /** @file Type declarations for environment provided in polyglot JVM runtime. */ +import type { YjsChannelCallbacks } from '../../ydoc-channel/dist/YjsChannel' + declare class WebSocketServer { constructor(config: any) onconnect: ((socket: any, url: any) => any) | null start(): void } +declare class Java { + static type(name: string): any +} + declare const YDOC_HOST: string | undefined declare const YDOC_PORT: number | undefined declare const YDOC_LS_DEBUG: boolean | undefined +declare const YDOC_JSON_CHANNEL_CALLBACKS: YjsChannelCallbacks | undefined +declare const YDOC_BINARY_CHANNEL_CALLBACKS: YjsChannelCallbacks | undefined // rust ffi shims declare function parse_block(code: string): Uint8Array diff --git a/app/ydoc-server/package.json b/app/ydoc-server/package.json index 25f85aa4e989..5e2d533c7732 100644 --- a/app/ydoc-server/package.json +++ b/app/ydoc-server/package.json @@ -28,6 +28,7 @@ "lib0": "^0.2.99", "modern-isomorphic-ws": "^1.0.5", "y-protocols": "^1.0.6", + "ydoc-channel": "workspace:*", "ydoc-shared": "workspace:*", "yjs": "^13.6.21", "zod": "catalog:" diff --git a/app/ydoc-server/src/YjsBinaryChannel.ts b/app/ydoc-server/src/YjsBinaryChannel.ts new file mode 100644 index 000000000000..0223761a96e0 --- /dev/null +++ b/app/ydoc-server/src/YjsBinaryChannel.ts @@ -0,0 +1,66 @@ +import * as map from 'lib0/map' +import { YjsChannel, type MessageHandler, type YjsChannelCallbacks } from 'ydoc-channel' +import * as Y from 'yjs' + +/** + * A Yjs channel that handles binary data communication using ByteBuffer. + * Extends YjsChannel to provide binary message encoding/decoding capabilities. + */ +export class YjsBinaryChannel extends YjsChannel { + private static channels = new Map() + + private readonly callbacks: YjsChannelCallbacks + private readonly ByteBuffer: any + + /** + * Creates a new YjsBinaryChannel instance. + * @param doc - The Yjs document to synchronize + * @param channelName - The name of the channel + * @param callbacks - Callbacks for channel lifecycle events + * @param byteBuffer - Java ByteBuffer class + */ + constructor(doc: Y.Doc, channelName: string, callbacks: YjsChannelCallbacks, byteBuffer: any) { + super(doc, channelName) + this.callbacks = callbacks + this.ByteBuffer = byteBuffer + this.callbacks.onConnect(this) + } + + /** Get a {@link YjsBinaryChannel}. */ + static get( + doc: Y.Doc, + channelName: string, + callbacks: YjsChannelCallbacks, + byteBuffer: any, + ): YjsBinaryChannel { + return map.setIfUndefined(YjsBinaryChannel.channels, channelName, () => { + return new YjsBinaryChannel(doc, channelName, callbacks, byteBuffer) + }) + } + + /** + * Sends a message through the channel. + * Converts the message to a Uint8Array before sending. + * @param message - The message to send + */ + override send(message: any): void { + const arr = new Uint8Array(new ArrayBuffer(message)) + super.send(arr as T) + } + + /** + * Subscribes to incoming messages on the channel. + * Converts incoming Uint8Array messages to Java ByteBuffer before passing to the handler. + * @param handler - The message handler function + * @returns A function to unsubscribe from the channel + */ + override subscribe(handler: MessageHandler): () => void { + const f = (contents: Uint8Array) => { + const bb = this.ByteBuffer.allocateDirect(contents.byteLength) + const arr = new Uint8Array(new ArrayBuffer(bb)) + arr.set(contents) + return handler(bb) + } + return super.subscribe(f as MessageHandler) + } +} diff --git a/app/ydoc-server/src/index.ts b/app/ydoc-server/src/index.ts index 8a9b65778498..b30ebbfdeb7d 100644 --- a/app/ydoc-server/src/index.ts +++ b/app/ydoc-server/src/index.ts @@ -9,11 +9,7 @@ */ import debug from 'debug' -import type { Server } from 'http' -import type { Http2SecureServer } from 'http2' -import type WS from 'modern-isomorphic-ws' -import type { IncomingMessage } from 'node:http' -import { docName, type ConnectionData } from './auth' +import { docName } from './auth' import { deserializeIdMap } from './serialization' import { setupGatewayClient, WSSharedDoc, YjsConnection, type YjsSocket } from './ydoc' @@ -23,74 +19,10 @@ export { deserializeIdMap, docName, setupGatewayClient, WSSharedDoc, YjsConnecti export function configureAllDebugLogs( forceEnable: boolean, customLogger?: (...args: any[]) => any, -): void { +) { for (const debugModule of ['ydoc-server:session', 'ydoc-shared:languageServer']) { const instance = debug(debugModule) if (forceEnable) instance.enabled = true if (customLogger) instance.log = customLogger } } - -/** Create a WebSocket server to host the YDoc coordinating server. */ -export async function createGatewayServer( - httpServer: Server | Http2SecureServer, - overrideLanguageServerUrl?: string, -): Promise { - const { WebSocketServer } = (await import('modern-isomorphic-ws')).default - const { parse } = await import('node:url') - - const wss = new WebSocketServer({ noServer: true }) - wss.on('connection', (ws: WS, _request: IncomingMessage, data: ConnectionData) => { - ws.on('error', onWebSocketError) - try { - const wsArrayBuffer = Object.assign(ws, { binaryType: 'arraybuffer' } as const) - setupGatewayClient(wsArrayBuffer, data.lsUrl, data.doc) - } catch (e) { - if (e instanceof Error) { - onWebSocketError(e) - ws.close(1003, e.message) - } else throw e - } - }) - - httpServer.on('upgrade', (request, socket, head) => { - socket.on('error', onHttpSocketError) - authenticate(request, function next(err, data) { - if (err != null || data == null) { - socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n') - socket.destroy() - return - } - socket.removeListener('error', onHttpSocketError) - wss.handleUpgrade(request, socket, head, function done(ws: WS) { - wss.emit('connection', ws, request, data) - }) - }) - }) - - function onWebSocketError(err: Error) { - console.log('WebSocket error:', err) - } - - function onHttpSocketError(err: Error) { - console.log('HTTP socket error:', err) - } - - function authenticate( - request: IncomingMessage, - callback: (err: Error | null, authData: ConnectionData | null) => void, - ) { - // FIXME: Stub. We don't implement authentication for now. Need to be implemented in combination - // with the language server. - const user = 'mock-user' - - if (request.url == null) return callback(null, null) - const { pathname, query } = parse(request.url, true) - if (pathname == null) return callback(null, null) - const doc = docName(pathname) - const lsUrl = - overrideLanguageServerUrl ?? (typeof query.ls === 'string' ? (query.ls as string) : null) - const data = doc != null ? { lsUrl, doc, user } : null - callback(null, data) - } -} diff --git a/app/ydoc-server/src/languageServerSession.ts b/app/ydoc-server/src/languageServerSession.ts index 1b6e9a64f4c1..79892d3378f1 100644 --- a/app/ydoc-server/src/languageServerSession.ts +++ b/app/ydoc-server/src/languageServerSession.ts @@ -4,6 +4,7 @@ import * as json from 'lib0/json' import * as map from 'lib0/map' import { ObservableV2 } from 'lib0/observable' import * as zlib from 'node:zlib' +import { type YjsChannelCallbacks } from 'ydoc-channel' import * as Ast from 'ydoc-shared/ast' import { astCount } from 'ydoc-shared/ast' import { combineFileParts, splitFileContents, type EnsoFileParts } from 'ydoc-shared/ensoFile' @@ -19,7 +20,7 @@ import type { } from 'ydoc-shared/languageServerTypes' import { assertNever } from 'ydoc-shared/util/assert' import { AbortScope, exponentialBackoff, printingCallbacks } from 'ydoc-shared/util/net' -import { ReconnectingWebSocketTransport } from 'ydoc-shared/util/net/ReconnectingWSTransport' +import { YjsServerTransport } from 'ydoc-shared/util/net/YjsTransport' import { DistributedProject, IdMap, @@ -61,13 +62,13 @@ export class LanguageServerSession { static DEBUG = false /** Create a {@link LanguageServerSession}. */ - constructor(ls: LanguageServer, unregister: () => void) { + constructor(ls: LanguageServer, indexDoc: WSSharedDoc, unregister: () => void) { this.clientScope = new AbortScope() this.docs = new Map() this.retainCount = 0 this.ls = ls this.unregister = unregister - this.indexDoc = new WSSharedDoc() + this.indexDoc = indexDoc this.docs.set('index', this.indexDoc) this.model = new DistributedProject(this.indexDoc.doc) this.projectRootId = null @@ -88,11 +89,14 @@ export class LanguageServerSession { static sessions: Map = new Map() /** Get a {@link LanguageServerSession} by its URL. */ - static get(url: string): LanguageServerSession { + static get(url: string, callbacks: YjsChannelCallbacks): LanguageServerSession { const session = map.setIfUndefined(LanguageServerSession.sessions, url, () => { - const ws = new ReconnectingWebSocketTransport(url) - const ls = new LanguageServer(crypto.randomUUID(), ws) - return new LanguageServerSession(ls, () => LanguageServerSession.sessions.delete(url)) + const indexDoc = new WSSharedDoc() + const transport = new YjsServerTransport(indexDoc.doc, url, callbacks) + const ls = new LanguageServer(crypto.randomUUID(), transport) + return new LanguageServerSession(ls, indexDoc, () => + LanguageServerSession.sessions.delete(url), + ) }) session.retain() return session diff --git a/app/ydoc-server/src/ydoc.ts b/app/ydoc-server/src/ydoc.ts index f1eb79876dee..d35ccbeeebc7 100644 --- a/app/ydoc-server/src/ydoc.ts +++ b/app/ydoc-server/src/ydoc.ts @@ -10,6 +10,8 @@ import * as Y from 'yjs' import * as decoding from 'lib0/decoding' import * as encoding from 'lib0/encoding' import { ObservableV2 } from 'lib0/observable' +import type { YjsChannelCallbacks } from 'ydoc-channel' +import { YjsBinaryChannel } from './YjsBinaryChannel' import { LanguageServerSession } from './languageServerSession' const pingTimeout = 30000 @@ -96,19 +98,31 @@ export class WSSharedDoc { export function setupGatewayClient( ws: YjsSocket, lsUrl: string | undefined | null, + dataUrl: string | undefined | null, docName: string, + byteBuffer: any, + jsonChannelCallbacks: YjsChannelCallbacks, + binaryChannelCallbacks: YjsChannelCallbacks, ): void { - console.log(`setupGatewayClient(${lsUrl ? 'lsUrl: ' + lsUrl : 'no lsUrl'}, docName: ${docName})`) - const lsSession = getSessionForUrl(lsUrl) + console.log( + `setupGatewayClient(${lsUrl ? 'lsUrl: ' + lsUrl : 'no lsUrl'}, ${dataUrl ? 'dataUrl: ' + dataUrl : 'no dataUrl'} docName: ${docName}), byteBuffer: ${byteBuffer}, jsonChannelCallbacks: ${jsonChannelCallbacks}, binaryChannelCallbacks: ${binaryChannelCallbacks}`, + ) + const lsSession = getSessionForUrl(lsUrl, jsonChannelCallbacks) const wsDoc = getSessionDoc(lsSession, docName) if (!wsDoc) { ws.close() return } + let dataSocket: YjsBinaryChannel | undefined + if (dataUrl) { + dataSocket = YjsBinaryChannel.get(wsDoc.doc, dataUrl, binaryChannelCallbacks, byteBuffer) + } + const connection = new YjsConnection(ws, wsDoc) connection.once('close', async () => { try { + dataSocket?.close() await lsSession.release() } catch (error) { console.error('Session release failed.\n', error) @@ -116,10 +130,13 @@ export function setupGatewayClient( }) } -function getSessionForUrl(lsUrl: string | undefined | null) { +function getSessionForUrl( + lsUrl: string | undefined | null, + jsonChannelCallbacks: YjsChannelCallbacks, +) { let lsSession: LanguageServerSession if (lsUrl) { - lsSession = LanguageServerSession.get(lsUrl) + lsSession = LanguageServerSession.get(lsUrl, jsonChannelCallbacks) } else { const anySession = LanguageServerSession.sessions.values().next().value if (LanguageServerSession.sessions.size === 1 && anySession) { diff --git a/app/ydoc-server/tsconfig.json b/app/ydoc-server/tsconfig.json index 40dc02f22304..109536b8f7b9 100644 --- a/app/ydoc-server/tsconfig.json +++ b/app/ydoc-server/tsconfig.json @@ -3,6 +3,7 @@ "compilerOptions": { "composite": true, "noEmit": false, "outDir": "dist", "rootDir": "src" }, "extends": "../../tsconfig.json", "files": [ + "./src/YjsBinaryChannel.ts", "./src/__tests__/edits.bench.ts", "./src/__tests__/edits.test.ts", "./src/auth.ts", diff --git a/app/ydoc-shared/package.json b/app/ydoc-shared/package.json index 5e32d6b31c35..414110075ab2 100644 --- a/app/ydoc-shared/package.json +++ b/app/ydoc-shared/package.json @@ -42,6 +42,7 @@ "partysocket": "^1.0.3", "rust-ffi": "workspace:*", "yjs": "^13.6.21", + "ydoc-channel": "workspace:*", "zod": "catalog:" }, "devDependencies": { diff --git a/app/ydoc-shared/src/languageServer.ts b/app/ydoc-shared/src/languageServer.ts index 55ee22e7b6ca..0b4f0b25b8e0 100644 --- a/app/ydoc-shared/src/languageServer.ts +++ b/app/ydoc-shared/src/languageServer.ts @@ -26,7 +26,7 @@ import type { VisualizationConfiguration, } from './languageServerTypes' import { AbortScope, exponentialBackoff } from './util/net' -import type { ReconnectingWebSocketTransport } from './util/net/ReconnectingWSTransport' +import type { YjsTransport } from './util/net/YjsTransport' import { isHeadless } from './util/types' import type { Uuid } from './yjsModel' @@ -149,7 +149,7 @@ export class LanguageServer extends ObservableV2 { - return this.initialized.then((result) => (result.ok ? result.value.contentRoots : [])) + return this.initialized.then((result) => { + return result.ok ? result.value.contentRoots : [] + }) } /** Reconnect the underlying network transport. */ reconnect() { - this.transport.reconnect() + console.log('LanguageServer.reconnect()') + this.transport.close() + this.transport.connect() } // The "magic bag of holding" generic that is only present in the return type is UNSOUND. diff --git a/app/ydoc-shared/src/util/net/YjsTransport.ts b/app/ydoc-shared/src/util/net/YjsTransport.ts new file mode 100644 index 000000000000..0ec7f62e7f25 --- /dev/null +++ b/app/ydoc-shared/src/util/net/YjsTransport.ts @@ -0,0 +1,185 @@ +/** + * A JSON-RPC transport implementation that uses YjsChannel for communication. + * This allows JSON-RPC to work over Y.js CRDT synchronization. + */ + +import { JSONRPCError } from '@open-rpc/client-js' +import { ERR_UNKNOWN } from '@open-rpc/client-js/build/Error.js' +import { + getBatchRequests, + getNotifications, + type JSONRPCRequestData, +} from '@open-rpc/client-js/build/Request.js' +import { Transport } from '@open-rpc/client-js/build/transports/Transport.js' +import type { YjsChannelCallbacks } from 'ydoc-channel' +import { YjsChannel } from 'ydoc-channel' +import type * as Y from 'yjs' + +export interface AddEventListenerOptions { + capture?: boolean + once?: boolean + passive?: boolean + signal?: AbortSignal +} + +/** Event map for YjsTransport events. */ +interface YjsEventMap { + open: Event + close: CloseEvent + message: MessageEvent + error: ErrorEvent +} + +type EventListener = (event: T) => void + +/** A JSON-RPC transport that uses YjsChannel for communication. */ +export class YjsTransport extends Transport { + protected channel: YjsChannel + protected doc: Y.Doc + protected channelName: string + protected eventListeners: Map>> = new Map() + + /** + * Create a {@link YjsTransport}. + * @param doc - The shared Y.Doc document + * @param channelName - The name of the channel (used to get/create the Y.Array) + */ + constructor(doc: Y.Doc, channelName: string) { + super() + this.doc = doc + this.channelName = channelName + this.channel = new YjsChannel(doc, channelName) + } + + /** + * Initiate the channel subscription. + */ + public connect(): Promise { + return new Promise((resolve) => { + this.channel.subscribe((message) => { + this.emit('message', new MessageEvent('message', { data: message })) + this.transportRequestManager.resolveResponse(message) + }) + this.emit('open', new Event('open')) + resolve() + }) + } + + /** + * Send JSON-RPC data through the channel. + */ + public async sendData(data: JSONRPCRequestData, timeout: number | null = 5000): Promise { + let prom = this.transportRequestManager.addRequest(data, timeout) + const notifications = getNotifications(data) + try { + const message = JSON.stringify(this.parseData(data)) + this.channel.send(message) + this.transportRequestManager.settlePendingRequest(notifications) + } catch (err) { + const jsonError = new JSONRPCError((err as any).message, ERR_UNKNOWN, err) + + this.emit('error', new ErrorEvent('error', { error: err, message: (err as any).message })) + this.transportRequestManager.settlePendingRequest(notifications, jsonError) + this.transportRequestManager.settlePendingRequest(getBatchRequests(data), jsonError) + + prom = Promise.reject(jsonError) + } + + return prom + } + + /** Close the channel and clean up subscriptions. */ + public close(): void { + this.channel.close() + this.emit('close', new CloseEvent('close')) + } + + /** Add an event listener. */ + on( + type: K, + cb: (event: YjsEventMap[K]) => void, + options?: AddEventListenerOptions, + ): void { + if (!this.eventListeners.has(type)) { + this.eventListeners.set(type, new Set()) + } + + const wrappedCb = (event: YjsEventMap[K]) => { + cb(event) + if (options?.once) { + this.off(type, cb) + } + } + + // Store original callback for later removal + ;(wrappedCb as any).__original = cb + + this.eventListeners.get(type)!.add(wrappedCb) + + // Handle abort signal + if (options?.signal) { + options.signal.addEventListener('abort', () => { + this.off(type, cb) + }) + } + } + + /** Remove an event listener. */ + off(type: K, cb: (event: YjsEventMap[K]) => void): void { + const listeners = this.eventListeners.get(type) + if (!listeners) return + + // Find and remove the listener with matching original callback + for (const listener of listeners) { + if ((listener as any).__original === cb || listener === cb) { + listeners.delete(listener) + break + } + } + } + + /** Emit an event to all registered listeners. */ + protected emit(type: K, event: YjsEventMap[K]): void { + const listeners = this.eventListeners.get(type) + if (!listeners) return + + for (const listener of listeners) { + listener(event) + } + } +} + +/** A JSON-RPC transport that uses YjsChannel for communication. */ +export class YjsServerTransport extends YjsTransport { + private readonly proxyChannel: YjsChannel + private readonly callbacks: YjsChannelCallbacks + + /** + * Create a {@link YjsTransport}. + * @param doc - The shared Y.Doc document + * @param channelName - The name of the channel (used to get/create the Y.Array) + */ + constructor(doc: Y.Doc, channelName: string, callbacks: YjsChannelCallbacks) { + super(doc, `backend-${channelName}`) + this.callbacks = callbacks + this.proxyChannel = new YjsChannel(doc, channelName) + } + + /** + * Initiate the channel subscription. + */ + override connect(): Promise { + const proxyConnect = new Promise((resolve) => { + this.callbacks.onConnect(this.proxyChannel) + this.callbacks.onConnect(new YjsChannel(this.doc, this.channelName)) + resolve() + }) + return proxyConnect.then(() => super.connect()) + } + + /** Close the channel and clean up subscriptions. */ + override close(): void { + this.proxyChannel.close() + super.close() + } +} diff --git a/app/ydoc-shared/src/util/net/__tests__/YjsTransport.test.ts b/app/ydoc-shared/src/util/net/__tests__/YjsTransport.test.ts new file mode 100644 index 000000000000..5d2f486c0108 --- /dev/null +++ b/app/ydoc-shared/src/util/net/__tests__/YjsTransport.test.ts @@ -0,0 +1,236 @@ +import type { IJSONRPCData } from '@open-rpc/client-js/build/Request.js' +import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest' +import * as Y from 'yjs' +import { YjsTransport } from '../YjsTransport' + +// Helper function to create JSON-RPC notification data (no id = no response expected) +function createNotification(method: string, params?: any): IJSONRPCData { + return { + internalID: Math.random(), + request: { + jsonrpc: '2.0', + method, + ...(params === undefined ? { params: null } : { params }), + }, + } +} + +// Polyfill DOM event types for Node.js environment +if (typeof globalThis.CloseEvent === 'undefined') { + ;(globalThis as any).CloseEvent = class CloseEvent extends Event { + constructor(type: string, options?: any) { + super(type, options) + } + } +} + +if (typeof globalThis.MessageEvent === 'undefined') { + ;(globalThis as any).MessageEvent = class MessageEvent extends Event { + data: any + constructor(type: string, options?: any) { + super(type, options) + this.data = options?.data + } + } +} + +if (typeof globalThis.ErrorEvent === 'undefined') { + ;(globalThis as any).ErrorEvent = class ErrorEvent extends Event { + error: any + message: string + constructor(type: string, options?: any) { + super(type, options) + this.error = options?.error + this.message = options?.message || '' + } + } +} + +describe('YjsTransport', () => { + let doc: Y.Doc + let transport1: YjsTransport + let transport2: YjsTransport + + beforeEach(async () => { + doc = new Y.Doc() + transport1 = new YjsTransport(doc, 'test-channel') + transport2 = new YjsTransport(doc, 'test-channel') + await transport1.connect() + await transport2.connect() + }) + + afterEach(() => { + transport1.close() + transport2.close() + }) + + test('sends message from one transport and receives on another', async () => { + const messageListener = vi.fn() + transport2.on('message', messageListener) + + // Send a notification through sendData + const notification = createNotification('test.method', { test: 'data', value: 123 }) + await transport1.sendData(notification) + + // Wait a bit for the message to propagate + await new Promise((resolve) => setTimeout(resolve, 10)) + + expect(messageListener).toHaveBeenCalledTimes(1) + expect(messageListener).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'message', + }), + ) + // Verify the message contains the notification data + const receivedData = JSON.parse(messageListener.mock.calls[0]?.[0].data) + expect(receivedData).toMatchObject({ + jsonrpc: '2.0', + method: 'test.method', + params: { test: 'data', value: 123 }, + }) + }) + + test('multiple transports can exchange messages bidirectionally', async () => { + const listener1 = vi.fn() + const listener2 = vi.fn() + + transport1.on('message', listener1) + transport2.on('message', listener2) + + // Send from transport1 to transport2 + const notification1 = createNotification('message', { text: 'from transport1' }) + await transport1.sendData(notification1) + + await new Promise((resolve) => setTimeout(resolve, 10)) + + expect(listener2).toHaveBeenCalledTimes(1) + expect(listener2).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'message', + }), + ) + expect(listener1).not.toHaveBeenCalled() + + // Send from transport2 to transport1 + const notification2 = createNotification('message', { text: 'from transport2' }) + await transport2.sendData(notification2) + + await new Promise((resolve) => setTimeout(resolve, 10)) + + expect(listener1).toHaveBeenCalledTimes(1) + expect(listener1).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'message', + }), + ) + expect(listener2).toHaveBeenCalledTimes(1) // Still only 1 call + }) + + test('subscribes and unsubscribes to message events', async () => { + const listener = vi.fn() + + transport2.on('message', listener) + + // Send message - should be received + await transport1.sendData(createNotification('test', { id: 1 })) + await new Promise((resolve) => setTimeout(resolve, 10)) + expect(listener).toHaveBeenCalledTimes(1) + + // Unsubscribe + transport2.off('message', listener) + + // Send another message - should NOT be received + await transport1.sendData(createNotification('test', { id: 2 })) + await new Promise((resolve) => setTimeout(resolve, 10)) + expect(listener).toHaveBeenCalledTimes(1) // Still only 1 call + }) + + test('supports multiple listeners for same event', async () => { + const listener1 = vi.fn() + const listener2 = vi.fn() + const listener3 = vi.fn() + + transport2.on('message', listener1) + transport2.on('message', listener2) + transport2.on('message', listener3) + + await transport1.sendData(createNotification('test')) + await new Promise((resolve) => setTimeout(resolve, 10)) + + expect(listener1).toHaveBeenCalledTimes(1) + expect(listener2).toHaveBeenCalledTimes(1) + expect(listener3).toHaveBeenCalledTimes(1) + }) + + test('supports once option for one-time listeners', async () => { + const listener = vi.fn() + + transport2.on('message', listener, { once: true }) + + await transport1.sendData(createNotification('test', { id: 1 })) + await new Promise((resolve) => setTimeout(resolve, 10)) + expect(listener).toHaveBeenCalledTimes(1) + + await transport1.sendData(createNotification('test', { id: 2 })) + await new Promise((resolve) => setTimeout(resolve, 10)) + expect(listener).toHaveBeenCalledTimes(1) // Still only 1 + }) + + test('supports abort signal for cancellable listeners', async () => { + const abortController = new AbortController() + const listener = vi.fn() + + transport2.on('message', listener, { signal: abortController.signal }) + + await transport1.sendData(createNotification('test', { id: 1 })) + await new Promise((resolve) => setTimeout(resolve, 10)) + expect(listener).toHaveBeenCalledTimes(1) + + // Abort the listener + abortController.abort() + + await transport1.sendData(createNotification('test', { id: 2 })) + await new Promise((resolve) => setTimeout(resolve, 10)) + expect(listener).toHaveBeenCalledTimes(1) // Still only 1 + }) + + test('emits open event on connect', async () => { + const transport3 = new YjsTransport(doc, 'test-channel-2') + const openListener = vi.fn() + + transport3.on('open', openListener) + await transport3.connect() + + expect(openListener).toHaveBeenCalledTimes(1) + expect(openListener).toHaveBeenCalledWith(expect.objectContaining({ type: 'open' })) + + transport3.close() + }) + + test('emits close event on close', async () => { + const transport3 = new YjsTransport(doc, 'test-channel-3') + const closeListener = vi.fn() + + transport3.on('close', closeListener) + await transport3.connect() + transport3.close() + + expect(closeListener).toHaveBeenCalledTimes(1) + expect(closeListener).toHaveBeenCalledWith(expect.objectContaining({ type: 'close' })) + }) + + test('can close without connecting', () => { + const transport3 = new YjsTransport(doc, 'test-channel-6') + expect(() => transport3.close()).not.toThrow() + }) + + test('does not receive own messages', async () => { + const listener = vi.fn() + transport1.on('message', listener) + + await transport1.sendData(createNotification('self.message')) + await new Promise((resolve) => setTimeout(resolve, 10)) + + expect(listener).not.toHaveBeenCalled() + }) +}) diff --git a/app/ydoc-shared/tsconfig.json b/app/ydoc-shared/tsconfig.json index a2dc451679f5..92c23c160034 100644 --- a/app/ydoc-shared/tsconfig.json +++ b/app/ydoc-shared/tsconfig.json @@ -50,6 +50,8 @@ "./src/util/net.ts", "./src/util/net/MockWSTransport.ts", "./src/util/net/ReconnectingWSTransport.ts", + "./src/util/net/YjsTransport.ts", + "./src/util/net/__tests__/YjsTransport.test.ts", "./src/util/types.ts", "./src/uuid.ts", "./src/yjsModel.ts" diff --git a/build.sbt b/build.sbt index b63d9ebc9dc3..db76e917f87c 100644 --- a/build.sbt +++ b/build.sbt @@ -428,6 +428,7 @@ lazy val enso = (project in file(".")) `test-utils`, `text-buffer`, `version-output`, + `ydoc-api`, `ydoc-polyfill`, `ydoc-server`, `ydoc-server-registration`, @@ -577,6 +578,7 @@ lazy val componentModulesPaths = (`task-progress-notifications` / Compile / exportedModuleBin).value, (`text-buffer` / Compile / exportedModuleBin).value, (`version-output` / Compile / exportedModuleBin).value, + (`ydoc-api` / Compile / exportedModuleBin).value, (`ydoc-polyfill` / Compile / exportedModuleBin).value, (`ydoc-server` / Compile / exportedModuleBin).value, (`ydoc-server-registration` / Compile / exportedModuleBin).value, @@ -1705,9 +1707,11 @@ lazy val `json-rpc-server` = project Compile / moduleDependencies ++= slf4jApi, Compile / internalModuleDependencies := Seq( (`akka-wrapper` / Compile / exportedModule).value, + (`ydoc-api` / Compile / exportedModule).value, (`scala-libs-wrapper` / Compile / exportedModule).value ) ) + .dependsOn(`ydoc-api`) .dependsOn(`runtime-utils` % "test->compile") // An automatic JPMS module @@ -1763,6 +1767,20 @@ lazy val searcher = project .dependsOn(`polyglot-api`) .dependsOn(testkit % Test) +lazy val `ydoc-api` = project + .in(file("lib/java/ydoc-api")) + .enablePlugins(JPMSPlugin) + .configs(Test) + .settings( + customFrgaalJavaCompilerSettings("21"), + javaModuleName := "org.enso.ydoc.api", + Compile / exportJars := true, + crossPaths := false, + autoScalaLibrary := false, + Test / fork := true, + commands += WithDebugCommand.withDebug + ) + lazy val `ydoc-polyfill` = project .in(file("lib/java/ydoc-polyfill")) .enablePlugins(JPMSPlugin) @@ -1794,6 +1812,7 @@ lazy val `ydoc-polyfill` = project .map(_ % "provided") ++ GraalVM.chromeInspectorPkgs ++ helidon } ) + .dependsOn(`ydoc-api`) .dependsOn(`syntax-rust-definition`) lazy val `ydoc-server` = project @@ -1812,6 +1831,7 @@ lazy val `ydoc-server` = project GraalVM.modules ++ GraalVM.jsPkgs ++ GraalVM.chromeInspectorPkgs ++ helidon ++ logbackPkg ++ slf4jApi, Compile / internalModuleDependencies := Seq( (`syntax-rust-definition` / Compile / exportedModule).value, + (`ydoc-api` / Compile / exportedModule).value, (`ydoc-polyfill` / Compile / exportedModule).value ), libraryDependencies ++= slf4jApi ++ Seq( @@ -1908,6 +1928,7 @@ lazy val `ydoc-server` = project ) .dependsOn(`jvm-interop`) .dependsOn(`logging-service-logback`) + .dependsOn(`ydoc-api`) .dependsOn(`ydoc-polyfill`) lazy val `ydoc-server-registration` = project @@ -1925,6 +1946,7 @@ lazy val `ydoc-server-registration` = project GraalVM.modules, Compile / internalModuleDependencies := Seq( (`engine-runner-common` / Compile / exportedModule).value, + (`ydoc-api` / Compile / exportedModule).value, (`jvm-channel` / Compile / exportedModule).value, (`jvm-interop` / Compile / exportedModule).value ), @@ -1939,6 +1961,7 @@ lazy val `ydoc-server-registration` = project } ) .dependsOn(`engine-runner-common`) + .dependsOn(`ydoc-api`) .dependsOn(`jvm-channel`) .dependsOn(`jvm-interop`) @@ -2222,6 +2245,7 @@ lazy val `language-server` = (project in file("engine/language-server")) (`connected-lock-manager-server` / Compile / exportedModule).value, (`language-server-deps-wrapper` / Compile / exportedModule).value, (`engine-runner-common` / Compile / exportedModule).value, + (`ydoc-api` / Compile / exportedModule).value, (`ydoc-polyfill` / Compile / exportedModule).value, (`engine-common` / Compile / exportedModule).value, (`library-manager` / Compile / exportedModule).value, @@ -2345,6 +2369,7 @@ lazy val `language-server` = (project in file("engine/language-server")) (`task-progress-notifications` / Compile / exportedModule).value, (`text-buffer` / Compile / exportedModule).value, (`version-output` / Compile / exportedModule).value, + (`ydoc-api` / Compile / exportedModule).value, (`ydoc-polyfill` / Compile / exportedModule).value ), Test / javaOptions ++= testLogProviderOptions, @@ -2362,6 +2387,7 @@ lazy val `language-server` = (project in file("engine/language-server")) javaModuleName.value, (`syntax-rust-definition` / javaModuleName).value, (`profiling-utils` / javaModuleName).value, + (`ydoc-api` / javaModuleName).value, (`ydoc-polyfill` / javaModuleName).value, (`library-manager` / javaModuleName).value ), @@ -2414,6 +2440,7 @@ lazy val `language-server` = (project in file("engine/language-server")) .dependsOn(testkit % Test) .dependsOn(`text-buffer`) .dependsOn(`version-output`) + .dependsOn(`ydoc-api`) .dependsOn(`ydoc-polyfill`) lazy val cleanInstruments = taskKey[Unit]( @@ -3645,12 +3672,14 @@ lazy val `engine-runner-common` = project (`library-manager` / Compile / exportedModule).value, (`logging-utils` / Compile / exportedModule).value, (`pkg` / Compile / exportedModule).value, + (`ydoc-api` / Compile / exportedModule).value, (`polyglot-api` / Compile / exportedModule).value ) ) .dependsOn(`edition-updater`) .dependsOn(`library-manager`) .dependsOn(`polyglot-api`) + .dependsOn(`ydoc-api`) .dependsOn(testkit % Test) lazy val `engine-runner` = project @@ -3709,6 +3738,7 @@ lazy val `engine-runner` = project (`runtime-version-manager` / Compile / exportedModule).value, (`semver` / Compile / exportedModule).value, (`version-output` / Compile / exportedModule).value, + (`ydoc-api` / Compile / exportedModule).value, (`ydoc-server-registration` / Compile / exportedModule).value ), // Runtime / modulePath is used as module-path for the native image build. @@ -4292,11 +4322,12 @@ lazy val `jvm-interop` = (Test / fork) := true, commands += WithDebugCommand.withDebug, libraryDependencies ++= Seq( - "org.graalvm.truffle" % "truffle-api" % graalMavenPackagesVersion % "provided", - "org.graalvm.truffle" % "truffle-dsl-processor" % graalMavenPackagesVersion % "provided", - "org.graalvm.sdk" % "graal-sdk" % graalMavenPackagesVersion % Test, - "junit" % "junit" % junitVersion % Test, - "com.github.sbt" % "junit-interface" % junitIfVersion % Test + "org.graalvm.truffle" % "truffle-api" % graalMavenPackagesVersion % "provided", + "org.graalvm.truffle" % "truffle-dsl-processor" % graalMavenPackagesVersion % "provided", + "org.graalvm.sdk" % "graal-sdk" % graalMavenPackagesVersion % Test, + "junit" % "junit" % junitVersion % Test, + "com.github.sbt" % "junit-interface" % junitIfVersion % Test, + "org.graalvm.polyglot" % "js-community" % graalMavenPackagesVersion % Test ), Compile / moduleDependencies ++= Seq( "org.graalvm.truffle" % "truffle-api" % graalMavenPackagesVersion, @@ -4308,7 +4339,8 @@ lazy val `jvm-interop` = (`jvm-channel` / Compile / exportedModule).value, (`engine-common` / Compile / exportedModule).value, (`persistance` / Compile / exportedModule).value - ) + ), + Test / libraryDependencies ++= GraalVM.jsPkgs ) .dependsOn(`engine-common`) .dependsOn(`jvm-channel`) @@ -6430,7 +6462,7 @@ lazy val lintEnso = "Run Enso linter on one or many projects. If no arguments are specified, all projects are linted. Otherwise, the argument should be the full path or just the name of the project to lint." ) lintEnso := { - buildEngineDistributionNoIndex.value + if (sys.env.get("CI").isEmpty) buildEngineDistributionNoIndex.value val fileTree = fileTreeView.value val args: Seq[String] = spaceDelimited("").parsed diff --git a/engine/language-server/src/main/java/module-info.java b/engine/language-server/src/main/java/module-info.java index 5509bcfb4f06..651c3a458275 100644 --- a/engine/language-server/src/main/java/module-info.java +++ b/engine/language-server/src/main/java/module-info.java @@ -37,6 +37,7 @@ requires org.enso.version.output; requires org.enso.text.buffer; requires org.enso.task.progress.notifications; + requires org.enso.ydoc.api; requires org.enso.ydoc.polyfill; exports org.enso.languageserver.filemanager to scala.library; diff --git a/engine/language-server/src/main/java/org/enso/languageserver/boot/resource/TruffleContextInitialization.java b/engine/language-server/src/main/java/org/enso/languageserver/boot/resource/TruffleContextInitialization.java index 4ddbd14d674e..7526d3020dde 100644 --- a/engine/language-server/src/main/java/org/enso/languageserver/boot/resource/TruffleContextInitialization.java +++ b/engine/language-server/src/main/java/org/enso/languageserver/boot/resource/TruffleContextInitialization.java @@ -6,6 +6,7 @@ import org.enso.common.LanguageInfo; import org.enso.languageserver.boot.ComponentSupervisor; import org.enso.languageserver.event.InitializedEvent; +import org.graalvm.polyglot.Context; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -15,6 +16,7 @@ public class TruffleContextInitialization extends LockedInitialization { private final ContextFactory contextFactory; private final ComponentSupervisor supervisor; private final EventStream eventStream; + private Context context; private final Logger logger = LoggerFactory.getLogger(this.getClass()); @@ -45,7 +47,12 @@ public void initComponent() { logger.trace("Created Runtime context [{}]", truffleContext); logger.trace("Initializing Runtime context [{}]", truffleContext); truffleContext.initialize(LanguageInfo.ID); + this.context = truffleContext; eventStream.publish(InitializedEvent.TruffleContextInitialized$.MODULE$); logger.trace("Initialized Runtime context [{}]", truffleContext); } + + public Context getContext() { + return context; + } } diff --git a/engine/language-server/src/main/resources/META-INF/native-image/org/enso/languageserver/reachability-metadata.json b/engine/language-server/src/main/resources/META-INF/native-image/org/enso/languageserver/reachability-metadata.json index 5aaf8890572c..849518984316 100644 --- a/engine/language-server/src/main/resources/META-INF/native-image/org/enso/languageserver/reachability-metadata.json +++ b/engine/language-server/src/main/resources/META-INF/native-image/org/enso/languageserver/reachability-metadata.json @@ -7,6 +7,50 @@ "allPublicConstructors": true, "allDeclaredFields": true, "allDeclaredMethods": true + }, + { + "type": "org.enso.languageserver.http.server.BinaryYdocServer$BinaryServerCallbacks", + "methods": [ + { + "name": "onConnect", + "parameterTypes": ["org.enso.ydoc.api.YjsChannel"] + } + ] + }, + { + "type": "org.enso.jsonrpc.YdocJsonRpcServer$ServerCallbacks", + "methods": [ + { + "name": "onConnect", + "parameterTypes": ["org.enso.ydoc.api.YjsChannel"] + } + ] + }, + { + "type": "org.enso.ydoc.server.YjsChannelSynchronized", + "allDeclaredMethods": true, + "allPublicMethods": true + }, + { + "type": { + "proxy": [ + "java.util.function.Consumer" + ] + } + }, + { + "type": { + "proxy": [ + "org.enso.ydoc.api.YjsChannelCallbacks" + ] + } + }, + { + "type": { + "proxy": [ + "org.enso.ydoc.api.YjsChannel" + ] + } } ] -} \ No newline at end of file +} diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/boot/LanguageServerComponent.scala b/engine/language-server/src/main/scala/org/enso/languageserver/boot/LanguageServerComponent.scala index a1b4720c0f1d..c48ac04f0247 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/boot/LanguageServerComponent.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/boot/LanguageServerComponent.scala @@ -57,38 +57,17 @@ class LanguageServerComponent(config: LanguageServerConfig, logLevel: Level) Future.successful(None) } } - val bindBinaryServer = - for { - binding <- module.binaryServer.bind(config.interface, config.dataPort) - _ <- Future { - logger.trace("Server for Binary WebSocket is initialized") - } - } yield binding - val bindSecureBinaryServer: Future[Option[Http.ServerBinding]] = { - config.secureDataPort match { - case Some(port) => - module.binaryServer - .bind(config.interface, port, secure = true) - .map(Some(_)) - case None => - Future.successful(None) - } - } for { - jsonBinding <- bindJsonServer - secureJsonBinding <- bindSecureJsonServer - binaryBinding <- bindBinaryServer - secureBinaryBinding <- bindSecureBinaryServer + jsonBinding <- bindJsonServer + secureJsonBinding <- bindSecureJsonServer _ <- Future { maybeServerCtx = Some( ServerContext( sampler, module, jsonBinding, - secureJsonBinding, - binaryBinding, - secureBinaryBinding + secureJsonBinding ) ) } @@ -149,10 +128,6 @@ class LanguageServerComponent(config: LanguageServerConfig, logLevel: Level) for { _ <- serverContext.jsonBinding.terminate(2.seconds).recover[Any](logError) _ <- Future { logger.info("Terminated JSON connections") } - _ <- serverContext.binaryBinding - .terminate(2.seconds) - .recover[Any](logError) - _ <- Future { logger.info("Terminated binary connections") } _ <- Await .ready( @@ -195,18 +170,14 @@ object LanguageServerComponent { * * @param sampler a sampler gathering the application performance statistics * @param mainModule a main module containing all components of the server - * @param jsonBinding a http binding for rpc protocol + * @param jsonBinding an http binding for rpc protocol * @param secureJsonBinding an optional https binding for rpc protocol - * @param binaryBinding a http binding for data protocol - * @param secureBinaryBinding an optional https binding for data protocol */ case class ServerContext( sampler: MethodsSampler, mainModule: MainModule, jsonBinding: Http.ServerBinding, - secureJsonBinding: Option[Http.ServerBinding], - binaryBinding: Http.ServerBinding, - secureBinaryBinding: Option[Http.ServerBinding] + secureJsonBinding: Option[Http.ServerBinding] ) } diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/boot/MainModule.scala b/engine/language-server/src/main/scala/org/enso/languageserver/boot/MainModule.scala index c00eb5956d4d..10ff51138ba1 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/boot/MainModule.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/boot/MainModule.scala @@ -12,13 +12,17 @@ import org.enso.distribution.{DistributionManager, Environment, LanguageHome} import org.enso.editions.EditionResolver import org.enso.profiling.events.EventsMonitor import org.enso.editions.updater.EditionManager -import org.enso.jsonrpc.{JsonRpcServer, SecureConnectionConfig} +import org.enso.jsonrpc.{ + JsonRpcServer, + SecureConnectionConfig, + YdocJsonRpcServer +} import org.enso.runner.common.CompilerBasedDependencyExtractor import org.enso.languageserver.capability.CapabilityRouter import org.enso.languageserver.data._ import org.enso.languageserver.effect import org.enso.languageserver.filemanager._ -import org.enso.languageserver.http.server.BinaryWebSocketServer +import org.enso.languageserver.http.server.BinaryYdocServer import org.enso.languageserver.io._ import org.enso.languageserver.libraries._ import org.enso.languageserver.monitoring.{ @@ -56,6 +60,7 @@ import org.enso.common.{ RuntimeOptions } import org.enso.filewatcher.WatcherFactory +import org.enso.languageserver.boot.resource.TruffleContextInitialization import org.enso.logging.utils.akka.AkkaConverter import org.enso.polyglot.RuntimeServerInfo import org.enso.searcher.memory.InMemorySuggestionsRepo @@ -70,6 +75,8 @@ import java.lang.management.ManagementFactory import java.net.URI import java.nio.charset.StandardCharsets import java.time.Clock +import java.util.concurrent.Executors + import scala.concurrent.duration.DurationInt /** A main module containing all components of the server. @@ -457,14 +464,23 @@ class MainModule(serverConfig: LanguageServerConfig, logLevel: Level) { private val jsonRpcProtocolFactory = new JsonRpcProtocolFactory + private val truffleContext = { + val contextInitialization = + new TruffleContextInitialization( + system.dispatcher, + builder, + contextSupervisor, + system.eventStream + ) + contextInitialization.initComponent() + contextInitialization.getContext + } private val initializationComponent = ResourcesInitialization( system.eventStream, directoriesConfig, jsonRpcProtocolFactory, suggestionsRepo, - builder, - contextSupervisor, zioRuntime )(system.dispatcher) @@ -501,7 +517,7 @@ class MainModule(serverConfig: LanguageServerConfig, logLevel: Level) { val materializer: Materializer = Materializer.createMaterializer(system) val jsonRpcServer = - new JsonRpcServer( + new YdocJsonRpcServer( jsonRpcProtocolFactory, jsonRpcControllerFactory, JsonRpcServer @@ -512,26 +528,39 @@ class MainModule(serverConfig: LanguageServerConfig, logLevel: Level) { ), List(healthCheckEndpoint, idlenessEndpoint, renameProjectEndpoint), messagesCallback - )(system, materializer) + )(system) log.trace("Created JSON RPC Server [{}]", jsonRpcServer) - val binaryServer = - new BinaryWebSocketServer( + val binaryChannelCallbacks = + new BinaryYdocServer.BinaryServerCallbacks( InboundMessageDecoder, BinaryEncoder.empty, new BinaryConnectionControllerFactory(fileManager)(system), - BinaryWebSocketServer.Config( - outgoingBufferSize = 100, - lazyMessageTimeout = 10.seconds, - secureConfig = secureConfig - ), - messagesCallback - )(system, materializer) - log.trace("Created Binary WebSocket Server [{}]", binaryServer) + messagesCallback, + truffleContext, + system + ) + log.trace("Created Binary Channel Callbacks [{}]", binaryChannelCallbacks) private val ydoc = { val c = org.enso.languageserver.boot.config.ApplicationConfig.load().ydoc - org.enso.runner.common.YdocServerApi.launchYdocServer(c.hostname, c.port) + val ydocExecutor = Executors.newSingleThreadExecutor(r => { + val t = new Thread(r) + t.setName("Ydoc main thread") + // Ydoc should not prevent JVM from exiting + t.setDaemon(true) + t + }) + ydocExecutor.execute(() => + org.enso.runner.common.YdocServerApi + .launchYdocServer( + c.hostname, + c.port, + jsonRpcServer.yjsChannelCallbacks, + binaryChannelCallbacks + ) + ) + ydocExecutor } log.debug( diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/boot/ResourcesInitialization.scala b/engine/language-server/src/main/scala/org/enso/languageserver/boot/ResourcesInitialization.scala index 67c663fd07f7..93a19fd391b9 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/boot/ResourcesInitialization.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/boot/ResourcesInitialization.scala @@ -9,13 +9,11 @@ import org.enso.languageserver.boot.resource.{ JsonRpcInitialization, RepoInitialization, SequentialResourcesInitialization, - TruffleContextInitialization, ZioRuntimeInitialization } import org.enso.languageserver.data.ProjectDirectoriesConfig import org.enso.languageserver.effect import org.enso.searcher.memory.InMemorySuggestionsRepo -import org.enso.common.ContextFactory import scala.concurrent.ExecutionContextExecutor @@ -30,8 +28,6 @@ object ResourcesInitialization { * @param directoriesConfig configuration of directories that should be created * @param protocolFactory the JSON-RPC protocol factory * @param suggestionsRepo the suggestions repo - * @param truffleContextBuilder the runtime context - * @param truffleContextSupervisor the runtime component supervisor * @param runtime the runtime to run effects * @return the initialization component */ @@ -40,8 +36,6 @@ object ResourcesInitialization { directoriesConfig: ProjectDirectoriesConfig, protocolFactory: ProtocolFactory, suggestionsRepo: InMemorySuggestionsRepo, - truffleContextBuilder: ContextFactory, - truffleContextSupervisor: ComponentSupervisor, runtime: effect.Runtime )(implicit ec: ExecutionContextExecutor): InitializationComponent = { new SequentialResourcesInitialization( @@ -55,12 +49,6 @@ object ResourcesInitialization { directoriesConfig, eventStream, suggestionsRepo - ), - new TruffleContextInitialization( - ec, - truffleContextBuilder, - truffleContextSupervisor, - eventStream ) ) ) diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/http/server/BinaryWebSocketServer.scala b/engine/language-server/src/main/scala/org/enso/languageserver/http/server/BinaryWebSocketServer.scala index 2f101da3e12e..24a020a7003e 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/http/server/BinaryWebSocketServer.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/http/server/BinaryWebSocketServer.scala @@ -2,8 +2,6 @@ package org.enso.languageserver.http.server import akka.NotUsed import akka.actor.{ActorRef, ActorSystem} -import akka.http.scaladsl.model.RemoteAddress -import akka.http.scaladsl.model.StatusCodes.InternalServerError import akka.http.scaladsl.model.ws.{BinaryMessage, Message, TextMessage} import akka.http.scaladsl.server.Directives._ import akka.http.scaladsl.server.Route @@ -66,26 +64,16 @@ class BinaryWebSocketServer[A, B]( } private val route: Route = - extractClientIP { - case RemoteAddress.Unknown => - complete( - InternalServerError.toString + - "Set akka.http.server.remote-address-header to on" - ) - - case ip: RemoteAddress.IP => - path(config.path) { - get { handleWebSocketMessages(newConnection(ip)) } - } + path(config.path) { + get { handleWebSocketMessages(newConnection()) } } private def newConnection( - ip: RemoteAddress.IP ): Flow[Message, Message, NotUsed] = { - val frontController = factory.createController(ip) + val frontController = factory.createController() - val inboundFlow = createInboundFlow(frontController, ip) + val inboundFlow = createInboundFlow(frontController) val outboundFlow = createOutboundFlow(frontController) Flow.fromSinkAndSource(inboundFlow, outboundFlow) @@ -116,16 +104,14 @@ class BinaryWebSocketServer[A, B]( } private def createInboundFlow( - frontController: ActorRef, - ip: RemoteAddress.IP + frontController: ActorRef ): Sink[Message, NotUsed] = { val flow = Flow[Message] .mapConcat[BinaryMessage] { case msg: TextMessage => logger.warn( - "Received text message [{}] over the data connection [{}].", - msg, - ip + "Received text message [{}] over the data connection.", + msg ) Nil diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/http/server/BinaryYdocServer.scala b/engine/language-server/src/main/scala/org/enso/languageserver/http/server/BinaryYdocServer.scala new file mode 100644 index 000000000000..977f96e51800 --- /dev/null +++ b/engine/language-server/src/main/scala/org/enso/languageserver/http/server/BinaryYdocServer.scala @@ -0,0 +1,94 @@ +package org.enso.languageserver.http.server + +import akka.actor.{Actor, ActorRef, ActorSystem, Props} +import com.typesafe.scalalogging.LazyLogging +import org.enso.languageserver.http.server.BinaryWebSocketControlProtocol.OutboundStreamEstablished +import org.enso.languageserver.util.binary.{BinaryDecoder, BinaryEncoder} +import org.enso.ydoc.api.{YjsChannel, YjsChannelCallbacks} +import org.graalvm.polyglot.Context + +import java.lang.foreign.MemorySegment +import java.nio.ByteBuffer + +import scala.util.control.NonFatal + +object BinaryYdocServer { + + /** A web socket server using a binary protocol. + * + * @param decoder a decoder for inbound packets + * @param encoder an encoder for outbound packets + * @param factory creates front controller per a single connection that is responsible for handling all incoming requests + * @param messageCallbacks a list of message callbacks + * @param system an actor system that hosts the server + * @param context a runtime context + * @tparam A a type of messages sent to a connection controller + * @tparam B a type of messages received from a connection controller + */ + final class BinaryServerCallbacks[A, B]( + decoder: BinaryDecoder[A], + encoder: BinaryEncoder[B], + factory: ConnectionControllerFactory, + messageCallbacks: List[ByteBuffer => Unit], + context: Context, + system: ActorSystem + ) extends YjsChannelCallbacks + with LazyLogging { + + override def onConnect(channel: YjsChannel): Unit = { + logger.info("BinaryServerCallbacks.onConnect") + + val incomingMessageHandler = factory.createController() + channel.subscribe(this.onMessage(incomingMessageHandler, _)) + + val outgoingMessageHandler = + system.actorOf( + Props(new OutgoingMessageHandler(channel, encoder)) + ) + incomingMessageHandler ! OutboundStreamEstablished(outgoingMessageHandler) + } + + private def onMessage( + incomingMessageHandler: ActorRef, + message: Object + ): Unit = { + logger.info(s"BinaryServerCallbacks.onMessage ${message.getClass}") + try { + val value = context.asValue(message) + val address = value.asNativePointer() + val segment = + MemorySegment.ofAddress(address).reinterpret(value.getBufferSize()); + val buffer = segment.asByteBuffer() + val decoded = decoder.decode(buffer) + logger.info(s"Received binary message $decoded") + incomingMessageHandler ! decoded + messageCallbacks.foreach(cb => cb(buffer)) + } catch { + case NonFatal(e) => + logger.error( + s"Received unsupported message: ${message.getClass}", + e + ) + } + } + } + + final class OutgoingMessageHandler[B]( + channel: YjsChannel, + encoder: BinaryEncoder[B] + ) extends Actor + with LazyLogging { + + override def receive: Receive = { + case message: B @unchecked => + logger.info(s"Sending binary message $message") + val bytes = encoder.encode(message) + channel.send(bytes.compact()) + case unknown => + logger.error( + s"Sending unsupported message ${unknown.getClass}", + unknown + ) + } + } +} diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/http/server/ConnectionControllerFactory.scala b/engine/language-server/src/main/scala/org/enso/languageserver/http/server/ConnectionControllerFactory.scala index 89bc1d3beb4e..8653a7f72aa6 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/http/server/ConnectionControllerFactory.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/http/server/ConnectionControllerFactory.scala @@ -1,17 +1,13 @@ package org.enso.languageserver.http.server import akka.actor.ActorRef -import akka.http.scaladsl.model.RemoteAddress -/** A factory of connection controllers. - */ +/** A factory of connection controllers. */ trait ConnectionControllerFactory { /** Creates a connection controller that acts as front controller. * - * @param clientIp a client ip that the connection controller is created for * @return actor ref of created connection controller */ - def createController(clientIp: RemoteAddress.IP): ActorRef - + def createController(): ActorRef } diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/BinaryConnectionController.scala b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/BinaryConnectionController.scala index 858749581907..43ea3de53c74 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/BinaryConnectionController.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/BinaryConnectionController.scala @@ -1,7 +1,6 @@ package org.enso.languageserver.protocol.binary import akka.actor.{Actor, ActorRef, Props, Stash} -import akka.http.scaladsl.model.RemoteAddress import com.google.flatbuffers.FlatBufferBuilder import com.typesafe.scalalogging.LazyLogging import org.enso.languageserver.event.{ @@ -40,11 +39,8 @@ import scala.concurrent.duration._ /** An actor handling data communications between a single client and the * language server. It acts as a front controller responsible for handling * all incoming requests and dispatching commands. - * - * @param clientIp a client ip that the connection controller is created for */ class BinaryConnectionController( - clientIp: RemoteAddress.IP, fileManager: ActorRef, requestTimeout: FiniteDuration = 10.seconds ) extends Actor @@ -81,9 +77,8 @@ class BinaryConnectionController( val session = BinarySession(clientId, self) context.system.eventStream.publish(BinarySessionInitialized(session)) logger.info( - "Data session initialized for client: {} [{}].", - clientId, - clientIp + "Data session initialized for client: {}.", + clientId ) context.become( connectionEndHandler(Some(session)) @@ -122,7 +117,7 @@ class BinaryConnectionController( maybeDataSession: Option[BinarySession] = None ): Receive = { case ConnectionClosed => - logger.info("Connection closed [{}].", clientIp) + logger.info("Connection closed.") maybeDataSession.foreach(session => context.system.eventStream.publish(BinarySessionTerminated(session)) ) @@ -130,8 +125,7 @@ class BinaryConnectionController( case ConnectionFailed(th) => logger.error( - "An error occurred during processing web socket connection [{}].", - clientIp, + "An error occurred during processing web socket connection.", th ) maybeDataSession.foreach(session => diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/BinaryConnectionControllerFactory.scala b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/BinaryConnectionControllerFactory.scala index 3da3e33ef207..4d6440cd2057 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/BinaryConnectionControllerFactory.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/BinaryConnectionControllerFactory.scala @@ -1,7 +1,6 @@ package org.enso.languageserver.protocol.binary import akka.actor.{ActorRef, ActorSystem, Props} -import akka.http.scaladsl.model.RemoteAddress import org.enso.languageserver.http.server.ConnectionControllerFactory /** A factory for binary connection controllers. @@ -13,8 +12,8 @@ class BinaryConnectionControllerFactory(fileManager: ActorRef)(implicit ) extends ConnectionControllerFactory { /** @inheritdoc */ - override def createController(clientIp: RemoteAddress.IP): ActorRef = { - system.actorOf(Props(new BinaryConnectionController(clientIp, fileManager))) + override def createController(): ActorRef = { + system.actorOf(Props(new BinaryConnectionController(fileManager))) } } diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/InboundMessageDecoder.scala b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/InboundMessageDecoder.scala index 473a38835240..d0dafca90b26 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/InboundMessageDecoder.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/binary/InboundMessageDecoder.scala @@ -2,7 +2,6 @@ package org.enso.languageserver.protocol.binary import java.nio.ByteBuffer -import org.enso.languageserver.protocol.binary.{InboundMessage, InboundPayload} import org.enso.languageserver.util.binary.DecodingFailure.{ DataCorrupted, EmptyPayload, @@ -10,8 +9,7 @@ import org.enso.languageserver.util.binary.DecodingFailure.{ } import org.enso.languageserver.util.binary.{BinaryDecoder, DecodingFailure} -/** A decoder for an [[InboundMessage]]. - */ +/** A decoder for an [[InboundMessage]]. */ object InboundMessageDecoder extends BinaryDecoder[InboundMessage] { /** @inheritdoc */ diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala index a2c62d837a2e..7b1524843644 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala @@ -163,7 +163,7 @@ class JsonConnectionController( } override def receive: Receive = LoggingReceive { - case JsonRpcServer.WebConnect(webActor, _) => + case JsonRpcServer.WebConnect(webActor) => unstashAll() context.become(connected(webActor)) case _ => stash() @@ -216,7 +216,7 @@ class JsonConnectionController( case Request(_, id, _) => sender() ! ResponseError(Some(id), SessionNotInitialisedError) - case MessageHandler.Disconnected(_) => + case MessageHandler.Disconnected() => context.stop(self) } @@ -363,7 +363,7 @@ class JsonConnectionController( sender() ! ResponseError(Some(id), SessionAlreadyInitialisedError) } - case MessageHandler.Disconnected(_) => + case MessageHandler.Disconnected() => logger.info("Session terminated [{}].", rpcSession.clientId) context.system.eventStream.publish(JsonSessionTerminated(rpcSession)) context.stop(self) diff --git a/engine/language-server/src/test/scala/org/enso/languageserver/websocket/binary/BaseBinaryServerTest.scala b/engine/language-server/src/test/scala/org/enso/languageserver/websocket/binary/BaseBinaryServerTest.scala index 0c3f6df8c648..d1f6fe1e77d8 100644 --- a/engine/language-server/src/test/scala/org/enso/languageserver/websocket/binary/BaseBinaryServerTest.scala +++ b/engine/language-server/src/test/scala/org/enso/languageserver/websocket/binary/BaseBinaryServerTest.scala @@ -3,7 +3,6 @@ import java.nio.ByteBuffer import java.nio.file.Files import java.util.UUID import akka.actor.{ActorRef, Props} -import akka.http.scaladsl.model.RemoteAddress import com.google.flatbuffers.FlatBufferBuilder import org.apache.commons.io.FileUtils import org.enso.runner.common.ProfilingConfig @@ -76,7 +75,7 @@ abstract class BaseBinaryServerTest extends BinaryServerTestKit { } override def connectionControllerFactory: ConnectionControllerFactory = { - (clientIp: RemoteAddress.IP) => + () => { val testExecutor = ExecutionContext.fromExecutor(threadPool) val zioRuntime = new ExecutionContextRuntime(testExecutor) @@ -100,7 +99,7 @@ abstract class BaseBinaryServerTest extends BinaryServerTestKit { val controller = system.actorOf( - Props(new BinaryConnectionController(clientIp, fileManager)) + Props(new BinaryConnectionController(fileManager)) ) lastConnectionController = controller controller diff --git a/engine/runner-common/src/main/java/module-info.java b/engine/runner-common/src/main/java/module-info.java index 3614cdaa6b40..af138a308706 100644 --- a/engine/runner-common/src/main/java/module-info.java +++ b/engine/runner-common/src/main/java/module-info.java @@ -8,6 +8,7 @@ requires org.enso.polyglot.api; requires org.slf4j; requires org.enso.logging.utils; + requires org.enso.ydoc.api; requires scala.library; exports org.enso.runner.common; diff --git a/engine/runner-common/src/main/java/org/enso/runner/common/YdocServerApi.java b/engine/runner-common/src/main/java/org/enso/runner/common/YdocServerApi.java index 348c29f99b2a..decb17096077 100644 --- a/engine/runner-common/src/main/java/org/enso/runner/common/YdocServerApi.java +++ b/engine/runner-common/src/main/java/org/enso/runner/common/YdocServerApi.java @@ -3,9 +3,15 @@ import java.io.IOException; import java.net.URISyntaxException; import java.util.ServiceLoader; +import org.enso.ydoc.api.YjsChannelCallbacks; public abstract class YdocServerApi { - public static AutoCloseable launchYdocServer(String hostname, int port) + + public static AutoCloseable launchYdocServer( + String hostname, + int port, + YjsChannelCallbacks jsonChannelCallbacks, + YjsChannelCallbacks binaryChannelCallbacks) throws WrongOption, IOException, URISyntaxException { var loader = YdocServerApi.class.getClassLoader(); var it = ServiceLoader.load(YdocServerApi.class, loader).iterator(); @@ -13,9 +19,13 @@ public static AutoCloseable launchYdocServer(String hostname, int port) throw new WrongOption("No Ydoc server implementation found"); } var impl = it.next(); - return impl.runYdocServer(hostname, port); + return impl.runYdocServer(hostname, port, jsonChannelCallbacks, binaryChannelCallbacks); } - protected abstract AutoCloseable runYdocServer(String hostname, int port) + protected abstract AutoCloseable runYdocServer( + String hostname, + int port, + YjsChannelCallbacks jsonChannelCallbacks, + YjsChannelCallbacks binaryChannelCallbacks) throws WrongOption, IOException, URISyntaxException; } diff --git a/lib/java/jvm-interop/src/main/java/org/enso/jvm/interop/impl/OtherInteropType.java b/lib/java/jvm-interop/src/main/java/org/enso/jvm/interop/impl/OtherInteropType.java index f60187d4303a..d8aa64364d55 100644 --- a/lib/java/jvm-interop/src/main/java/org/enso/jvm/interop/impl/OtherInteropType.java +++ b/lib/java/jvm-interop/src/main/java/org/enso/jvm/interop/impl/OtherInteropType.java @@ -4,6 +4,7 @@ import com.oracle.truffle.api.interop.InteropLibrary; import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.library.Message; +import com.oracle.truffle.api.strings.TruffleString; import java.io.IOException; import java.math.BigInteger; import java.nio.charset.StandardCharsets; @@ -600,4 +601,25 @@ protected Duration readObject(Persistance.Input in) throws IOException, ClassNot return Duration.ofSeconds(s, n); } } + + @Persistable(id = 127) + static final class PersistTruffleString extends Persistance { + + public PersistTruffleString() { + super(TruffleString.class, true, 127); + } + + @Override + protected void writeObject(TruffleString obj, Persistance.Output out) throws IOException { + var s = obj.toJavaStringUncached(); + out.writeUTF(s); + } + + @Override + protected TruffleString readObject(Persistance.Input in) + throws IOException, ClassNotFoundException { + var s = in.readUTF(); + return TruffleString.fromJavaStringUncached(s, TruffleString.Encoding.UTF_8); + } + } } diff --git a/lib/java/jvm-interop/src/test/java/org/enso/jvm/interop/impl/OtherJvmJavaScriptTest.java b/lib/java/jvm-interop/src/test/java/org/enso/jvm/interop/impl/OtherJvmJavaScriptTest.java new file mode 100644 index 000000000000..45327b4dda1b --- /dev/null +++ b/lib/java/jvm-interop/src/test/java/org/enso/jvm/interop/impl/OtherJvmJavaScriptTest.java @@ -0,0 +1,105 @@ +package org.enso.jvm.interop.impl; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.oracle.truffle.api.TruffleLanguage; +import org.enso.jvm.channel.Channel; +import org.enso.test.utils.ContextUtils; +import org.graalvm.polyglot.Context; +import org.graalvm.polyglot.Value; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; + +public class OtherJvmJavaScriptTest { + @ClassRule + public static final ContextUtils ctx = ContextUtils.newBuilder("host").assertGC(false).build(); + + private static Channel CHANNEL; + + @BeforeClass + public static void initializeChannel() { + System.setProperty("org.enso.jvm.interop.limit", "" + Integer.MAX_VALUE); + CHANNEL = Channel.create(null, OtherJvmPool.class); + CHANNEL + .getConfig() + .onEnterLeave( + FakeLanguage.class, + null, + (__) -> { + ctx.context().enter(); + return null; + }, + (__, ___) -> { + ctx.context().leave(); + }); + } + + @Test + public void wrapTruffleString() throws Exception { + var testClassValue = loadOtherJvmClass(OtherJvmJavaScriptTest.class.getName()); + assertOtherJvmObject("Represents clazz from the other JVM", testClassValue); + + var result = + new ResultCallbacks() { + private Object value; + + @Override + public void onMessage(Object o) { + this.value = o; + } + }; + + var returnedResult = testClassValue.invokeMember("multiString", "Hello", 3, result); + // assertOtherJvmObject("Represents object from the other JVM", otherValue); + + assertEquals("HelloHelloHello", returnedResult.asString()); + assertEquals("HelloHelloHello", result.value.toString()); + } + + public static String multiString(String txt, int count, ResultCallbacks onResult) { + try (var jsCtx = Context.newBuilder("js").build()) { + var fn = + jsCtx.eval( + "js", + """ + (function(txt, count, onResult) { + let sb = ""; + for (let i = 0; i < count; i++) { + sb = sb + txt; + } + onResult.onMessage(sb); + return sb; + }) + """); + var res = fn.execute(txt, count, onResult).asString(); + return res; + } + } + + private static Value loadOtherJvmClass(String name) throws Exception { + var msg = new OtherJvmMessage.LoadClass(name); + var raw = CHANNEL.execute(OtherJvmResult.class, msg).value(null); + if (raw instanceof OtherJvmObject other) { + assertTrue(other.assertChannel(CHANNEL)); + } + var value = ctx.asValue(raw); + return value; + } + + private static void assertOtherJvmObject(String msg, Value value) { + var unwrap = ctx.unwrapValue(value); + if (unwrap instanceof OtherJvmObject) { + return; + } + fail(msg + " but got: " + unwrap); + } + + public static interface ResultCallbacks { + public void onMessage(Object o); + } + + private abstract static class FakeLanguage extends TruffleLanguage {} +} diff --git a/lib/java/ydoc-api/src/main/java/module-info.java b/lib/java/ydoc-api/src/main/java/module-info.java new file mode 100644 index 000000000000..9bda9ae25064 --- /dev/null +++ b/lib/java/ydoc-api/src/main/java/module-info.java @@ -0,0 +1,3 @@ +module org.enso.ydoc.api { + exports org.enso.ydoc.api; +} diff --git a/lib/java/ydoc-api/src/main/java/org/enso/ydoc/api/YjsChannel.java b/lib/java/ydoc-api/src/main/java/org/enso/ydoc/api/YjsChannel.java new file mode 100644 index 000000000000..b1dd72380df4 --- /dev/null +++ b/lib/java/ydoc-api/src/main/java/org/enso/ydoc/api/YjsChannel.java @@ -0,0 +1,9 @@ +package org.enso.ydoc.api; + +import java.util.function.Consumer; + +public interface YjsChannel { + public void send(Object message); + + public void subscribe(Consumer messageHandler); +} diff --git a/lib/java/ydoc-api/src/main/java/org/enso/ydoc/api/YjsChannelCallbacks.java b/lib/java/ydoc-api/src/main/java/org/enso/ydoc/api/YjsChannelCallbacks.java new file mode 100644 index 000000000000..1564101e7743 --- /dev/null +++ b/lib/java/ydoc-api/src/main/java/org/enso/ydoc/api/YjsChannelCallbacks.java @@ -0,0 +1,6 @@ +package org.enso.ydoc.api; + +public interface YjsChannelCallbacks { + + public void onConnect(YjsChannel channel); +} diff --git a/lib/java/ydoc-polyfill/src/main/java/org/enso/ydoc/polyfill/web/WebEnvironment.java b/lib/java/ydoc-polyfill/src/main/java/org/enso/ydoc/polyfill/web/WebEnvironment.java index 05f7f619fa3f..5392fb063830 100644 --- a/lib/java/ydoc-polyfill/src/main/java/org/enso/ydoc/polyfill/web/WebEnvironment.java +++ b/lib/java/ydoc-polyfill/src/main/java/org/enso/ydoc/polyfill/web/WebEnvironment.java @@ -60,7 +60,10 @@ public static Context.Builder createContext() { public static Context.Builder createContext(HostAccess hostAccess) { var contextBuilder = - Context.newBuilder("js").allowHostAccess(hostAccess).allowExperimentalOptions(true); + Context.newBuilder("js") + .allowHostAccess(hostAccess) + .allowExperimentalOptions(true) + .allowHostClassLookup(className -> className.equals("java.nio.ByteBuffer")); var inspectPort = Integer.getInteger("inspectPort", -1); if (inspectPort > 0) { diff --git a/lib/java/ydoc-polyfill/src/main/resources/org/enso/ydoc/polyfill/web/event-target.js b/lib/java/ydoc-polyfill/src/main/resources/org/enso/ydoc/polyfill/web/event-target.js index d69e03ae3f8d..63c89bbd8f57 100644 --- a/lib/java/ydoc-polyfill/src/main/resources/org/enso/ydoc/polyfill/web/event-target.js +++ b/lib/java/ydoc-polyfill/src/main/resources/org/enso/ydoc/polyfill/web/event-target.js @@ -37,6 +37,27 @@ } } + class CloseEvent extends Event { + + constructor(type, options) { + super(type, options) + } + } + + class MessageEvent extends Event { + + constructor(type, options) { + super(type, options) + } + } + + class ErrorEvent extends Event { + + constructor(type, options) { + super(type, options) + } + } + class EventTarget { #eventStore; @@ -66,6 +87,12 @@ globalThis.Event = Event; + globalThis.CloseEvent = CloseEvent; + + globalThis.MessageEvent = MessageEvent; + + globalThis.ErrorEvent = ErrorEvent; + globalThis.EventTarget = EventTarget; }) diff --git a/lib/java/ydoc-polyfill/src/test/java/org/enso/ydoc/api/CallbacksTest.java b/lib/java/ydoc-polyfill/src/test/java/org/enso/ydoc/api/CallbacksTest.java new file mode 100644 index 000000000000..c2d232e6aaf8 --- /dev/null +++ b/lib/java/ydoc-polyfill/src/test/java/org/enso/ydoc/api/CallbacksTest.java @@ -0,0 +1,134 @@ +package org.enso.ydoc.api; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import org.enso.ydoc.polyfill.ExecutorSetup; +import org.enso.ydoc.polyfill.web.WebEnvironment; +import org.graalvm.polyglot.Context; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class CallbacksTest extends ExecutorSetup { + + private Context context; + + public CallbacksTest() {} + + public final class TestCallbacks implements YjsChannelCallbacks { + + private Consumer handler; + + public TestCallbacks(Consumer handler) { + this.handler = handler; + } + + @Override + public void onConnect(YjsChannel channel) { + this.handler.accept(channel); + } + } + + @Before + public void setup() throws Exception { + super.setup(); + + var hostAccess = + WebEnvironment.defaultHostAccess + // allowImplementations is required to call methods on JS objects from Java, + // i.e. to call `YjsChannel::send` in the `TestCallbacks::onConnect` method + .allowImplementations(YjsChannel.class) + // public access is required to recognize Java lambdas passed to + // `YjsChannel::subscribe` method as JS functions. + .allowPublicAccess(true) + .allowAccess(TestCallbacks.class.getDeclaredMethod("onConnect", YjsChannel.class)) + .allowAccess(AtomicReference.class.getDeclaredMethod("set", Object.class)) + .build(); + var contextBuilder = WebEnvironment.createContext(hostAccess); + + context = CompletableFuture.supplyAsync(contextBuilder::build, executor).get(); + } + + @After + public void tearDown() throws InterruptedException { + super.tearDown(); + context.close(); + } + + @Test + public void onConnectSend() throws Exception { + var res = new AtomicReference<>(); + var code = + """ + class YjsChannel { + send(message) { + res.set(message); + } + } + + var channel = new YjsChannel(); + callbacks.onConnect(channel); + """; + + var callbacks = new TestCallbacks((channel) -> channel.send("Hello!")); + context.getBindings("js").putMember("callbacks", callbacks); + context.getBindings("js").putMember("res", res); + + CompletableFuture.runAsync(() -> context.eval("js", code), executor).get(); + + Assert.assertEquals("Hello!", res.get()); + } + + @Test + public void onConnectSubscribe() throws Exception { + var res = new AtomicReference<>(); + var code = + """ + class YjsChannel { + subscribe(messageHandler) { + messageHandler('World!'); + } + } + + var channel = new YjsChannel(); + callbacks.onConnect(channel); + """; + + var callbacks = + new TestCallbacks((channel) -> channel.subscribe((message) -> res.set(message))); + context.getBindings("js").putMember("callbacks", callbacks); + + CompletableFuture.runAsync(() -> context.eval("js", code), executor).get(); + + Assert.assertEquals("World!", res.get()); + } + + @Test + public void onConnectSubscribeBuffer() throws Exception { + var res = new AtomicReference<>(); + var code = + """ + class YjsChannel { + subscribe(messageHandler) { + var arr = new Uint8Array([0, 128, 255]); + messageHandler(arr.buffer); + } + } + + var channel = new YjsChannel(); + callbacks.onConnect(channel); + """; + + var callbacks = + new TestCallbacks((channel) -> channel.subscribe((message) -> res.set(message))); + context.getBindings("js").putMember("callbacks", callbacks); + + CompletableFuture.runAsync(() -> context.eval("js", code), executor).get(); + var value = context.asValue(res.get()); + var arr = value.as(byte[].class); + + Assert.assertArrayEquals(new byte[] {0, -128, -1}, arr); + } +} diff --git a/lib/java/ydoc-server-registration/src/main/java/module-info.java b/lib/java/ydoc-server-registration/src/main/java/module-info.java index f93e0451c5b0..c79594c69b41 100644 --- a/lib/java/ydoc-server-registration/src/main/java/module-info.java +++ b/lib/java/ydoc-server-registration/src/main/java/module-info.java @@ -4,6 +4,7 @@ requires org.graalvm.polyglot; requires org.enso.jvm.interop; requires org.enso.jvm.channel; + requires org.enso.ydoc.api; requires static org.graalvm.nativeimage; // only register service, otherwise the module isn't needed diff --git a/lib/java/ydoc-server-registration/src/main/java/org/enso/ydoc/server/registration/YdocServerImpl.java b/lib/java/ydoc-server-registration/src/main/java/org/enso/ydoc/server/registration/YdocServerImpl.java index 1026229712c5..9987fa3ded2b 100644 --- a/lib/java/ydoc-server-registration/src/main/java/org/enso/ydoc/server/registration/YdocServerImpl.java +++ b/lib/java/ydoc-server-registration/src/main/java/org/enso/ydoc/server/registration/YdocServerImpl.java @@ -6,14 +6,18 @@ import org.enso.jvm.interop.api.OtherJvmClassLoader; import org.enso.runner.common.WrongOption; import org.enso.runner.common.YdocServerApi; +import org.enso.ydoc.api.YjsChannelCallbacks; import org.graalvm.nativeimage.ImageInfo; -import org.graalvm.polyglot.proxy.ProxyArray; public final class YdocServerImpl extends YdocServerApi { public YdocServerImpl() {} @Override - protected AutoCloseable runYdocServer(String hostname, int port) + protected AutoCloseable runYdocServer( + String hostname, + int port, + YjsChannelCallbacks jsonChannelCallbacks, + YjsChannelCallbacks binaryChannelCallbacks) throws WrongOption, IOException, URISyntaxException { // the following shall invoke: // return launch(hostname, port); @@ -40,8 +44,7 @@ protected AutoCloseable runYdocServer(String hostname, int port) var fqn = "org.enso.ydoc.server.Main"; var impl = loader.loadClass(fqn); assert impl != null; - var arr = ProxyArray.fromArray(hostname, "" + port); - impl.invokeMember("main", arr); + impl.invokeMember("launch", hostname, port + "", jsonChannelCallbacks, binaryChannelCallbacks); return loader; } } diff --git a/lib/java/ydoc-server/src/main/java/module-info.java b/lib/java/ydoc-server/src/main/java/module-info.java index c1cabce7b408..0149780accfe 100644 --- a/lib/java/ydoc-server/src/main/java/module-info.java +++ b/lib/java/ydoc-server/src/main/java/module-info.java @@ -1,5 +1,6 @@ module org.enso.ydoc.server { requires io.helidon.common; + requires org.enso.ydoc.api; requires org.enso.ydoc.polyfill; requires org.graalvm.polyglot; requires org.slf4j; diff --git a/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/Main.java b/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/Main.java index c31bcd582125..8187ce56498d 100644 --- a/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/Main.java +++ b/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/Main.java @@ -1,59 +1,47 @@ package org.enso.ydoc.server; import java.io.IOException; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Semaphore; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.enso.ydoc.api.YjsChannel; +import org.enso.ydoc.api.YjsChannelCallbacks; +import org.enso.ydoc.polyfill.web.WebEnvironment; public final class Main { - private static final Logger log = LoggerFactory.getLogger(Main.class); - - private static final String ENSO_YDOC_HOST = "ENSO_YDOC_HOST"; - private static final String ENSO_YDOC_PORT = "ENSO_YDOC_PORT"; - - private static final Semaphore lock = new Semaphore(0); private Main() {} - public static void main(String[] args) throws Exception { - System.setProperty( - "helidon.serialFilter.pattern", - "javax.management.**;java.lang.**;java.rmi.**;javax.security.auth.Subject;!*"); - - if (args.length == 2) { - var then = System.currentTimeMillis(); - var hostname = args[0]; - var port = args[1]; - launch(hostname, port); - - var now = System.currentTimeMillis(); - var took = now - then; - log.debug("Ydoc server at {}:{} started in {} ms", hostname, port, took); - } else { - var hostname = System.getenv(ENSO_YDOC_HOST); - var port = System.getenv(ENSO_YDOC_PORT); - try (var ydoc = launch(hostname, port)) { - lock.acquire(); - } - } + public static void main(String[] args) { + // main method declaration is required to build the native library } - private static AutoCloseable launch(String ydocHost, String ydocPort) throws IOException { - try { - var builder = Ydoc.builder(); - if (ydocHost != null) { - builder.hostname(ydocHost); - } - if (ydocPort != null) { - var port = Integer.parseInt(ydocPort); - builder.port(port); - } - var ydoc = builder.build(); - ydoc.start(); - return ydoc; - } catch (ExecutionException | InterruptedException ex) { - throw new IOException(ex); + public static AutoCloseable launch( + String ydocHost, + String ydocPort, + YjsChannelCallbacks jsonChannelCallbacks, + YjsChannelCallbacks binaryChannelCallbacks) + throws IOException { + var builder = Ydoc.builder(); + if (ydocHost != null) { + builder.hostname(ydocHost); + } + if (ydocPort != null) { + var port = Integer.parseInt(ydocPort); + builder.port(port); + } + if (jsonChannelCallbacks != null) { + builder.jsonChannelCallbacks(jsonChannelCallbacks); + } + if (binaryChannelCallbacks != null) { + builder.binaryChannelCallbacks(binaryChannelCallbacks); } + var hostAccess = + WebEnvironment.defaultHostAccess + // allowImplementations is required to call methods on JS objects from Java, i.e. to + // call methods on `YjsChannel` object returned from JS + .allowImplementations(YjsChannel.class) + .allowPublicAccess(true); + builder.hostAccessBuilder(hostAccess); + var ydoc = builder.build(); + ydoc.start(); + return ydoc; } } diff --git a/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/Ydoc.java b/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/Ydoc.java index cde1853d73ef..f3cac47ebfb6 100644 --- a/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/Ydoc.java +++ b/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/Ydoc.java @@ -1,14 +1,13 @@ package org.enso.ydoc.server; import java.io.IOException; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.enso.ydoc.api.YjsChannel; +import org.enso.ydoc.api.YjsChannelCallbacks; import org.enso.ydoc.polyfill.ParserPolyfill; import org.enso.ydoc.polyfill.web.WebEnvironment; import org.graalvm.polyglot.Context; +import org.graalvm.polyglot.HostAccess; import org.graalvm.polyglot.Source; import org.graalvm.polyglot.io.IOAccess; @@ -17,25 +16,32 @@ public final class Ydoc implements AutoCloseable { private static final String YDOC_EXECUTOR_THREAD_NAME = "Ydoc executor thread"; private static final String YDOC_PATH = "ydoc.cjs"; - private final ScheduledExecutorService executor; + private final YdocScheduledExecutorService executor; private final ParserPolyfill parser; private final Context.Builder contextBuilder; private final String hostname; private final int port; + private final YjsChannelCallbacks jsonChannelCallbacks; + private final YjsChannelCallbacks binaryChannelCallbacks; + private final AtomicBoolean running = new AtomicBoolean(false); private Context context; private Ydoc( - ScheduledExecutorService executor, + YdocScheduledExecutorService executor, ParserPolyfill parser, Context.Builder contextBuilder, String hostname, - int port) { + int port, + YjsChannelCallbacks jsonChannelCallbacks, + YjsChannelCallbacks binaryChannelCallbacks) { this.executor = executor; this.parser = parser; this.contextBuilder = contextBuilder; this.hostname = hostname; this.port = port; + this.jsonChannelCallbacks = jsonChannelCallbacks; + this.binaryChannelCallbacks = binaryChannelCallbacks; } public static final class Builder { @@ -43,15 +49,38 @@ public static final class Builder { private static final String DEFAULT_HOSTNAME = "localhost"; private static final int DEFAULT_PORT = 5976; - private ScheduledExecutorService executor; + private YdocScheduledExecutorService executor; private ParserPolyfill parser; private Context.Builder contextBuilder; + private HostAccess.Builder hostAccessBuilder; private String hostname; private int port = -1; + private YjsChannelCallbacks jsonChannelCallbacks; + private YjsChannelCallbacks binaryChannelCallbacks; private Builder() {} - public Builder executor(ScheduledExecutorService executor) { + public static final class DelegateYjsChannelCallbacks implements YjsChannelCallbacks { + private final String name; + private final YjsChannelCallbacks delegate; + + DelegateYjsChannelCallbacks(String name, YjsChannelCallbacks delegate) { + this.name = name; + this.delegate = delegate; + } + + @HostAccess.Export + @Override + public void onConnect(YjsChannel channel) { + System.err.println("Enter onConnect[" + name + "] with " + channel + " for " + delegate); + if (delegate != null) { + delegate.onConnect(channel); + } + System.err.println("Exit onConnect[" + name + "] with " + channel); + } + } + + public Builder executor(YdocScheduledExecutorService executor) { this.executor = executor; return this; } @@ -61,6 +90,11 @@ public Builder parser(ParserPolyfill parser) { return this; } + public Builder hostAccessBuilder(HostAccess.Builder hostAccessBuilder) { + this.hostAccessBuilder = hostAccessBuilder; + return this; + } + public Builder contextBuilder(Context.Builder contextBuilder) { this.contextBuilder = contextBuilder; return this; @@ -76,23 +110,32 @@ public Builder port(int port) { return this; } + public Builder jsonChannelCallbacks(YjsChannelCallbacks callbacks) { + this.jsonChannelCallbacks = callbacks; + return this; + } + + public Builder binaryChannelCallbacks(YjsChannelCallbacks callbacks) { + this.binaryChannelCallbacks = callbacks; + return this; + } + public Ydoc build() { if (executor == null) { - executor = - Executors.newSingleThreadScheduledExecutor( - r -> { - var t = new Thread(r); - t.setName(YDOC_EXECUTOR_THREAD_NAME); - return t; - }); + executor = new YdocScheduledExecutorService(); } if (parser == null) { parser = new ParserPolyfill(); } + if (hostAccessBuilder == null) { + hostAccessBuilder = WebEnvironment.defaultHostAccess; + } + if (contextBuilder == null) { - contextBuilder = WebEnvironment.createContext().allowIO(IOAccess.ALL); + contextBuilder = + WebEnvironment.createContext(hostAccessBuilder.build()).allowIO(IOAccess.ALL); } if (hostname == null) { @@ -103,7 +146,14 @@ public Ydoc build() { port = DEFAULT_PORT; } - return new Ydoc(executor, parser, contextBuilder, hostname, port); + return new Ydoc( + executor, + parser, + contextBuilder, + hostname, + port, + new DelegateYjsChannelCallbacks("JSON", jsonChannelCallbacks), + new DelegateYjsChannelCallbacks("binary", binaryChannelCallbacks)); } } @@ -115,7 +165,15 @@ public Context.Builder getContextBuilder() { return contextBuilder; } - public void start() throws ExecutionException, InterruptedException, IOException { + public YjsChannelCallbacks getJsonChannelCallbacksSynchronized() { + return new YjsCallbacksSynchronized(jsonChannelCallbacks, executor); + } + + public YjsChannelCallbacks getBinaryChannelCallbacksSynchronized() { + return new YjsCallbacksSynchronized(binaryChannelCallbacks, executor); + } + + public void start() throws IOException { var ydoc = Main.class.getResource(YDOC_PATH); if (ydoc == null) { throw new AssertionError( @@ -125,30 +183,71 @@ public void start() throws ExecutionException, InterruptedException, IOException } var ydocJs = Source.newBuilder("js", ydoc).build(); - context = - CompletableFuture.supplyAsync( - () -> { - var ctx = contextBuilder.build(); - WebEnvironment.initialize(ctx, executor); - parser.initialize(ctx); + running.set(true); + + // Submit initialization task + var initFuture = + executor.submit( + () -> { + var ctx = contextBuilder.build(); + WebEnvironment.initialize(ctx, executor); + parser.initialize(ctx); + + var bindings = ctx.getBindings("js"); + bindings.putMember("YDOC_HOST", hostname); + bindings.putMember("YDOC_PORT", port); + bindings.putMember( + "YDOC_JSON_CHANNEL_CALLBACKS", getJsonChannelCallbacksSynchronized()); + bindings.putMember( + "YDOC_BINARY_CHANNEL_CALLBACKS", getBinaryChannelCallbacksSynchronized()); + bindings.putMember("YDOC_LS_DEBUG", "false"); + + ctx.eval(ydocJs); + + return ctx; + }); + + while (!initFuture.isDone()) { + executor.processPendingTasks(); + try { + long delay = executor.getNextTaskDelayNanos(); + executor.waitForTasks(delay); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } + } - var bindings = ctx.getBindings("js"); - bindings.putMember("YDOC_HOST", hostname); - bindings.putMember("YDOC_PORT", port); - bindings.putMember("YDOC_LS_DEBUG", "false"); + try { + context = initFuture.get(); + } catch (Exception e) { + throw new RuntimeException("Failed to initialize Ydoc", e); + } - ctx.eval(ydocJs); + runEventLoopBlocking(); + } - return ctx; - }, - executor) - .get(); + /** + * Runs the event loop continuously until {@link #close()} is called. This method blocks and + * should typically be run in a dedicated thread. + */ + public void runEventLoopBlocking() { + while (running.get()) { + executor.processPendingTasks(); + try { + long delay = executor.getNextTaskDelayNanos(); + executor.waitForTasks(delay); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } + } } @Override public void close() throws Exception { - executor.shutdownNow(); - executor.awaitTermination(3, TimeUnit.SECONDS); + running.set(false); + executor.shutdown(); if (context != null) { context.close(true); } diff --git a/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/YdocScheduledExecutorService.java b/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/YdocScheduledExecutorService.java new file mode 100644 index 000000000000..57a1d52eddf5 --- /dev/null +++ b/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/YdocScheduledExecutorService.java @@ -0,0 +1,650 @@ +package org.enso.ydoc.server; + +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.PriorityQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A single-threaded execution service that processes tasks on the thread where it was created. + * + *

This service maintains an event queue and executes tasks when {@link #processPendingTasks()} + * is called from the owner thread. It supports both immediate task execution and scheduled tasks + * with delays. + * + *

Thread Safety: This service is thread-safe for submitting tasks, but {@link + * #processPendingTasks()} must only be called from the owner thread. + */ +public final class YdocScheduledExecutorService implements ScheduledExecutorService { + + private final long ownerThreadId; + private final ConcurrentLinkedQueue immediateTasks; + private final Object scheduledTasksLock = new Object(); + private final PriorityQueue scheduledTasks; + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final Object waitLock = new Object(); + + /** Creates a new execution service bound to the current thread. */ + public YdocScheduledExecutorService() { + this.ownerThreadId = Thread.currentThread().threadId(); + this.immediateTasks = new ConcurrentLinkedQueue<>(); + this.scheduledTasks = new PriorityQueue<>(); + } + + /** + * Submits a task for immediate execution (internal method). + * + *

The task will be executed on the next call to {@link #processPendingTasks()} from the owner + * thread. + * + * @param task the task to execute + * @throws IllegalStateException if the service has been shut down + */ + private void submitInternal(Runnable task) { + if (shutdown.get()) { + throw new IllegalStateException("Service has been shut down"); + } + immediateTasks.offer(task); + synchronized (waitLock) { + waitLock.notifyAll(); + } + } + + @Override + public Future submit(Runnable task) { + submitInternal(task); + return java.util.concurrent.CompletableFuture.completedFuture(null); + } + + @Override + public Future submit(Runnable task, T result) { + submitInternal(task); + return java.util.concurrent.CompletableFuture.completedFuture(result); + } + + @Override + public Future submit(Callable task) { + var javaFuture = new java.util.concurrent.CompletableFuture(); + submitInternal( + () -> { + try { + javaFuture.complete(task.call()); + } catch (Throwable t) { + javaFuture.completeExceptionally(t); + } + }); + return javaFuture; + } + + @Override + public ScheduledFuture schedule(Runnable task, long delay, TimeUnit unit) { + if (shutdown.get()) { + throw new IllegalStateException("Service has been shut down"); + } + long executeAtNanos = System.nanoTime() + unit.toNanos(delay); + var cancellableTask = new CancellableTask(task, executeAtNanos); + synchronized (scheduledTasksLock) { + scheduledTasks.offer(new ScheduledTask(cancellableTask, executeAtNanos)); + } + synchronized (waitLock) { + waitLock.notifyAll(); + } + return cancellableTask; + } + + /** + * Schedules a task to execute after the specified duration. + * + * @param task the task to execute + * @param delay the delay before execution + * @throws IllegalStateException if the service has been shut down + */ + public void schedule(Runnable task, Duration delay) { + schedule(task, delay.toNanos(), TimeUnit.NANOSECONDS); + } + + @Override + public ScheduledFuture schedule(Callable task, long delay, TimeUnit unit) { + if (shutdown.get()) { + throw new IllegalStateException("Service has been shut down"); + } + long executeAtNanos = System.nanoTime() + unit.toNanos(delay); + var callableTask = new CallableScheduledFuture<>(task, executeAtNanos); + var wrapper = + new CancellableTask( + () -> { + try { + callableTask.complete(task.call()); + } catch (Throwable t) { + callableTask.completeExceptionally(t); + } + }, + executeAtNanos); + synchronized (scheduledTasksLock) { + scheduledTasks.offer(new ScheduledTask(wrapper, executeAtNanos)); + } + synchronized (waitLock) { + waitLock.notifyAll(); + } + return callableTask; + } + + @Override + public ScheduledFuture scheduleAtFixedRate( + Runnable task, long initialDelay, long period, TimeUnit unit) { + if (shutdown.get()) { + throw new IllegalStateException("Service has been shut down"); + } + + var repeatingTask = new RepeatingTask(task, unit.toNanos(period)); + long executeAtNanos = System.nanoTime() + unit.toNanos(initialDelay); + + var cancellableTask = new CancellableTask(repeatingTask, executeAtNanos); + synchronized (scheduledTasksLock) { + scheduledTasks.offer(new ScheduledTask(cancellableTask, executeAtNanos)); + } + synchronized (waitLock) { + waitLock.notifyAll(); + } + return cancellableTask; + } + + @Override + public ScheduledFuture scheduleWithFixedDelay( + Runnable task, long initialDelay, long delay, TimeUnit unit) { + // For our use case, fixed delay is similar to fixed rate + return scheduleAtFixedRate(task, initialDelay, delay, unit); + } + + /** A task that reschedules itself after execution. */ + private final class RepeatingTask implements Runnable { + private final Runnable task; + private final long periodNanos; + + RepeatingTask(Runnable task, long periodNanos) { + this.task = task; + this.periodNanos = periodNanos; + } + + @Override + public void run() { + try { + task.run(); + } catch (Throwable t) { + handleUncaughtException(t); + } + + // Reschedule for next execution + if (!shutdown.get()) { + long nextExecutionNanos = System.nanoTime() + periodNanos; + synchronized (scheduledTasksLock) { + scheduledTasks.offer(new ScheduledTask(this, nextExecutionNanos)); + } + synchronized (waitLock) { + waitLock.notifyAll(); + } + } + } + } + + /** + * A cancellable task wrapper that implements ScheduledFuture interface for compatibility with + * ScheduledExecutorService APIs. + */ + private static final class CancellableTask + implements Runnable, java.util.concurrent.ScheduledFuture { + private final Runnable task; + private final AtomicBoolean cancelled = new AtomicBoolean(false); + private final long executeAtNanos; + + CancellableTask(Runnable task, long executeAtNanos) { + this.task = task; + this.executeAtNanos = executeAtNanos; + } + + @Override + public void run() { + if (!cancelled.get()) { + task.run(); + } + } + + @Override + public long getDelay(TimeUnit unit) { + long delayNanos = executeAtNanos - System.nanoTime(); + return unit.convert(delayNanos, TimeUnit.NANOSECONDS); + } + + @Override + public int compareTo(java.util.concurrent.Delayed o) { + if (this == o) { + return 0; + } + long diff = getDelay(TimeUnit.NANOSECONDS) - o.getDelay(TimeUnit.NANOSECONDS); + return Long.signum(diff); + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return cancelled.compareAndSet(false, true); + } + + @Override + public boolean isCancelled() { + return cancelled.get(); + } + + @Override + public boolean isDone() { + return cancelled.get(); + } + + @Override + public Object get() { + return null; + } + + @Override + public Object get(long timeout, TimeUnit unit) { + return null; + } + } + + /** A ScheduledFuture for Callable tasks that holds the result. */ + private static final class CallableScheduledFuture implements ScheduledFuture { + private final long executeAtNanos; + private volatile V result; + private volatile Throwable exception; + private volatile boolean done = false; + private final Object lock = new Object(); + + CallableScheduledFuture(Callable task, long executeAtNanos) { + this.executeAtNanos = executeAtNanos; + } + + void complete(V result) { + synchronized (lock) { + if (done) { + return; + } + this.result = result; + this.done = true; + lock.notifyAll(); + } + } + + void completeExceptionally(Throwable exception) { + synchronized (lock) { + if (done) { + return; + } + this.exception = exception; + this.done = true; + lock.notifyAll(); + } + } + + @Override + public long getDelay(TimeUnit unit) { + long delayNanos = executeAtNanos - System.nanoTime(); + return unit.convert(delayNanos, TimeUnit.NANOSECONDS); + } + + @Override + public int compareTo(java.util.concurrent.Delayed o) { + if (this == o) { + return 0; + } + long diff = getDelay(TimeUnit.NANOSECONDS) - o.getDelay(TimeUnit.NANOSECONDS); + return Long.signum(diff); + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; // Cannot cancel after scheduled + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return done; + } + + @Override + public V get() throws InterruptedException, ExecutionException { + synchronized (lock) { + while (!done) { + lock.wait(); + } + if (exception != null) { + throw new ExecutionException(exception); + } + return result; + } + } + + @Override + public V get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + synchronized (lock) { + if (!done) { + lock.wait(unit.toMillis(timeout)); + } + if (!done) { + throw new TimeoutException(); + } + if (exception != null) { + throw new ExecutionException(exception); + } + return result; + } + } + } + + // Additional ExecutorService methods + + @Override + public void execute(Runnable command) { + submitInternal(command); + } + + @Override + public List shutdownNow() { + shutdown(); + return List.of(); // Cannot retrieve pending tasks in this implementation + } + + @Override + public boolean isTerminated() { + return isShutdown(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) { + return true; // No background threads to wait for + } + + @Override + public List> invokeAll(Collection> tasks) { + throw new UnsupportedOperationException("invokeAll not supported"); + } + + @Override + public List> invokeAll( + Collection> tasks, long timeout, TimeUnit unit) { + throw new UnsupportedOperationException("invokeAll not supported"); + } + + @Override + public T invokeAny(Collection> tasks) + throws InterruptedException, ExecutionException { + throw new UnsupportedOperationException("invokeAny not supported"); + } + + @Override + public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + throw new UnsupportedOperationException("invokeAny not supported"); + } + + /** + * Processes all pending tasks that are ready to execute. + * + *

This method must be called from the owner thread (the thread that created this service). It + * will execute all immediate tasks and any scheduled tasks whose delay has elapsed. + * + * @return the number of tasks executed + * @throws IllegalStateException if called from a thread other than the owner thread + */ + public int processPendingTasks() { + int tasksExecuted = 0; + long currentTime = System.nanoTime(); + + // Process immediate tasks + Runnable task; + while ((task = immediateTasks.poll()) != null) { + try { + task.run(); + tasksExecuted++; + } catch (Throwable t) { + handleUncaughtException(t); + } + } + + // Process scheduled tasks that are ready + synchronized (scheduledTasksLock) { + while (!scheduledTasks.isEmpty()) { + ScheduledTask scheduledTask = scheduledTasks.peek(); + if (scheduledTask.executeAtNanos <= currentTime) { + scheduledTasks.poll(); + try { + scheduledTask.task.run(); + tasksExecuted++; + } catch (Throwable t) { + handleUncaughtException(t); + } + } else { + break; // Tasks are sorted by time, so we can stop here + } + } + } + + return tasksExecuted; + } + + /** + * Returns true if there are any tasks pending execution. + * + * @return true if tasks are pending + */ + public boolean hasPendingTasks() { + if (!immediateTasks.isEmpty()) { + return true; + } + synchronized (scheduledTasksLock) { + if (scheduledTasks.isEmpty()) { + return false; + } + long currentTime = System.nanoTime(); + ScheduledTask next = scheduledTasks.peek(); + return next != null && next.executeAtNanos <= currentTime; + } + } + + /** + * Returns the number of nanoseconds until the next scheduled task is ready, or -1 if there are no + * scheduled tasks. + * + * @return nanoseconds until next task, or -1 if none + */ + public long getNextTaskDelayNanos() { + if (!immediateTasks.isEmpty()) { + return 0; + } + synchronized (scheduledTasksLock) { + ScheduledTask next = scheduledTasks.peek(); + if (next == null) { + return -1; + } + long delay = next.executeAtNanos - System.nanoTime(); + return Math.max(0, delay); + } + } + + /** + * Waits until tasks are available or the timeout expires. + * + *

This method blocks until either: + * + *

    + *
  • A new task is submitted (immediate or scheduled) + *
  • The specified timeout expires + *
  • The thread is interrupted + *
+ * + * @param timeoutNanos maximum time to wait in nanoseconds, or -1 to wait with a default timeout + * @throws InterruptedException if the thread is interrupted while waiting + */ + public void waitForTasks(long timeoutNanos) throws InterruptedException { + synchronized (waitLock) { + if (timeoutNanos > 0) { + long timeoutMillis = timeoutNanos / 1_000_000; + int timeoutNanosRemainder = (int) (timeoutNanos % 1_000_000); + waitLock.wait(timeoutMillis, timeoutNanosRemainder); + } else if (timeoutNanos == -1) { + waitLock.wait(10); + } + } + } + + /** + * Shuts down this service. No new tasks will be accepted after shutdown. + * + *

Pending tasks can still be processed with {@link #processPendingTasks()}. + */ + public void shutdown() { + shutdown.set(true); + synchronized (waitLock) { + waitLock.notifyAll(); + } + } + + /** + * Returns true if this service has been shut down. + * + * @return true if shut down + */ + public boolean isShutdown() { + return shutdown.get(); + } + + /** + * Returns the thread ID of the owner thread. + * + * @return the owner thread ID + */ + public long getOwnerThreadId() { + return ownerThreadId; + } + + private void handleUncaughtException(Throwable t) { + Thread currentThread = Thread.currentThread(); + Thread.UncaughtExceptionHandler handler = currentThread.getUncaughtExceptionHandler(); + if (handler != null) { + handler.uncaughtException(currentThread, t); + } else { + System.err.println("Uncaught exception in YdocScheduledExecutorService:"); + t.printStackTrace(); + } + } + + /** Internal class representing a scheduled task with its execution time. */ + private static final class ScheduledTask implements Comparable { + final Runnable task; + final long executeAtNanos; + + ScheduledTask(Runnable task, long executeAtNanos) { + this.task = task; + this.executeAtNanos = executeAtNanos; + } + + @Override + public int compareTo(ScheduledTask other) { + return Long.compare(this.executeAtNanos, other.executeAtNanos); + } + } + + /** A future-like result holder for scheduled tasks. */ + public static final class CompletableFuture { + private volatile V result; + private volatile Throwable exception; + private volatile boolean done = false; + private final Object lock = new Object(); + + void complete(V result) { + synchronized (lock) { + if (done) { + return; + } + this.result = result; + this.done = true; + lock.notifyAll(); + } + } + + void completeExceptionally(Throwable exception) { + synchronized (lock) { + if (done) { + return; + } + this.exception = exception; + this.done = true; + lock.notifyAll(); + } + } + + /** + * Returns true if this future is complete. + * + * @return true if complete + */ + public boolean isDone() { + return done; + } + + /** + * Gets the result, blocking until it's available. + * + * @return the result + * @throws Exception if the task threw an exception + */ + public V get() throws Exception { + synchronized (lock) { + while (!done) { + lock.wait(); + } + if (exception != null) { + if (exception instanceof Exception) { + throw (Exception) exception; + } else { + throw new Exception(exception); + } + } + return result; + } + } + + /** + * Gets the result without blocking, or returns null if not complete. + * + * @return the result, or null if not complete + * @throws RuntimeException if the task threw an exception + */ + public V getNow() { + if (!done) { + return null; + } + if (exception != null) { + if (exception instanceof RuntimeException) { + throw (RuntimeException) exception; + } else { + throw new RuntimeException(exception); + } + } + return result; + } + } +} diff --git a/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/YjsCallbacksSynchronized.java b/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/YjsCallbacksSynchronized.java new file mode 100644 index 000000000000..baa7f18be0d2 --- /dev/null +++ b/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/YjsCallbacksSynchronized.java @@ -0,0 +1,23 @@ +package org.enso.ydoc.server; + +import org.enso.ydoc.api.YjsChannel; +import org.enso.ydoc.api.YjsChannelCallbacks; +import org.graalvm.polyglot.HostAccess; + +public final class YjsCallbacksSynchronized implements YjsChannelCallbacks { + + private final YjsChannelCallbacks callbacks; + private final YdocScheduledExecutorService executor; + + YjsCallbacksSynchronized(YjsChannelCallbacks callbacks, YdocScheduledExecutorService executor) { + this.callbacks = callbacks; + this.executor = executor; + } + + @Override + @HostAccess.Export + public void onConnect(YjsChannel channel) { + var synchronizedChannel = new YjsChannelSynchronized(channel, this.executor); + this.callbacks.onConnect(synchronizedChannel); + } +} diff --git a/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/YjsChannelSynchronized.java b/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/YjsChannelSynchronized.java new file mode 100644 index 000000000000..2dc104d59038 --- /dev/null +++ b/lib/java/ydoc-server/src/main/java/org/enso/ydoc/server/YjsChannelSynchronized.java @@ -0,0 +1,25 @@ +package org.enso.ydoc.server; + +import java.util.function.Consumer; +import org.enso.ydoc.api.YjsChannel; + +public final class YjsChannelSynchronized implements YjsChannel { + + private final YjsChannel channel; + private final YdocScheduledExecutorService executor; + + public YjsChannelSynchronized(YjsChannel channel, YdocScheduledExecutorService executor) { + this.channel = channel; + this.executor = executor; + } + + @Override + public void send(Object message) { + executor.submit(() -> channel.send(message)); + } + + @Override + public void subscribe(Consumer messageHandler) { + executor.submit(() -> channel.subscribe(messageHandler)); + } +} diff --git a/lib/java/ydoc-server/src/main/resources/META-INF/native-image/org/enso/ydoc/reachability-metadata.json b/lib/java/ydoc-server/src/main/resources/META-INF/native-image/org/enso/ydoc/reachability-metadata.json index b3029ce3acf0..04a8f04e65a5 100644 --- a/lib/java/ydoc-server/src/main/resources/META-INF/native-image/org/enso/ydoc/reachability-metadata.json +++ b/lib/java/ydoc-server/src/main/resources/META-INF/native-image/org/enso/ydoc/reachability-metadata.json @@ -539,6 +539,15 @@ } ] }, + { + "type": "org.enso.languageserver.http.server.BinaryYdocServer$BinaryServerCallbacks", + "methods": [ + { + "name": "onConnect", + "parameterTypes": ["org.enso.ydoc.api.YjsChannel"] + } + ] + }, { "type": "org.enso.ydoc.polyfill.web.URL", "fields": [ @@ -548,7 +557,13 @@ { "name": "searchParams" } - ] + ], + "methods": [ + { + "name": "toString", + "parameterTypes": [] + } + ] }, { "type": "org.enso.ydoc.polyfill.web.URL$URLSearchParams", @@ -564,6 +579,15 @@ { "type": "org.enso.ydoc.server.Main", "methods": [ + { + "name": "launch", + "parameterTypes": [ + "java.lang.String", + "java.lang.String", + "org.enso.ydoc.api.YjsChannelCallbacks", + "org.enso.ydoc.api.YjsChannelCallbacks" + ] + }, { "name": "main", "parameterTypes": [ @@ -842,6 +866,27 @@ ] } ] + }, + { + "type": { + "proxy": [ + "org.enso.ydoc.api.YjsChannelCallbacks" + ] + } + }, + { + "type": { + "proxy": [ + "java.util.function.Consumer" + ] + } + }, + { + "type": { + "proxy": [ + "org.enso.ydoc.api.YjsChannel" + ] + } } ] -} \ No newline at end of file +} diff --git a/lib/java/ydoc-server/src/main/resources/META-INF/native-image/org/enso/ydoc/reflect-config.json b/lib/java/ydoc-server/src/main/resources/META-INF/native-image/org/enso/ydoc/reflect-config.json index 67bb9d33ba71..15e17482fb41 100644 --- a/lib/java/ydoc-server/src/main/resources/META-INF/native-image/org/enso/ydoc/reflect-config.json +++ b/lib/java/ydoc-server/src/main/resources/META-INF/native-image/org/enso/ydoc/reflect-config.json @@ -273,6 +273,15 @@ "parameterTypes": ["java.lang.String[]"] } ] + }, + { + "name": "org.enso.ydoc.server.YjsChannelSynchronized", + "methods": [ + { + "name": "subscribe", + "parameterTypes": ["java.util.function.Consumer"] + } + ] } ] diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/YdocScheduledExecutorServiceTest.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/YdocScheduledExecutorServiceTest.java new file mode 100644 index 000000000000..4c05987e9756 --- /dev/null +++ b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/YdocScheduledExecutorServiceTest.java @@ -0,0 +1,298 @@ +package org.enso.ydoc.server; + +import static org.junit.Assert.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Test; + +public class YdocScheduledExecutorServiceTest { + + @Test + public void testImmediateTaskExecution() { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + AtomicInteger counter = new AtomicInteger(0); + + // Submit immediate tasks + service.submit(() -> counter.incrementAndGet()); + service.submit(() -> counter.incrementAndGet()); + service.submit(() -> counter.incrementAndGet()); + + // Process tasks + int executed = service.processPendingTasks(); + + assertEquals(3, executed); + assertEquals(3, counter.get()); + } + + @Test + public void testTaskExecutionOrder() { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + List executionOrder = new ArrayList<>(); + + // Submit tasks in order + service.submit(() -> executionOrder.add(1)); + service.submit(() -> executionOrder.add(2)); + service.submit(() -> executionOrder.add(3)); + + service.processPendingTasks(); + + // Verify FIFO order + assertEquals(List.of(1, 2, 3), executionOrder); + } + + @Test + public void testScheduledTaskWithDelay() throws InterruptedException { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + AtomicInteger counter = new AtomicInteger(0); + + // Schedule a task with delay + service.schedule(() -> counter.incrementAndGet(), 50, TimeUnit.MILLISECONDS); + + // Process immediately - should not execute yet + int executed1 = service.processPendingTasks(); + assertEquals(0, executed1); + assertEquals(0, counter.get()); + + // Wait for the delay + Thread.sleep(60); + + // Process again - should execute now + int executed2 = service.processPendingTasks(); + assertEquals(1, executed2); + assertEquals(1, counter.get()); + } + + @Test + public void testMultipleScheduledTasks() throws InterruptedException { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + List executionOrder = new ArrayList<>(); + + // Schedule tasks with different delays + service.schedule(() -> executionOrder.add(3), 60, TimeUnit.MILLISECONDS); + service.schedule(() -> executionOrder.add(1), 20, TimeUnit.MILLISECONDS); + service.schedule(() -> executionOrder.add(2), 40, TimeUnit.MILLISECONDS); + + // Wait for all delays to pass + Thread.sleep(70); + + // Process - should execute in order of delay + service.processPendingTasks(); + + assertEquals(List.of(1, 2, 3), executionOrder); + } + + @Test + public void testMixedImmediateAndScheduledTasks() throws InterruptedException { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + List executionOrder = new ArrayList<>(); + + // Mix immediate and scheduled tasks + service.submit(() -> executionOrder.add("immediate1")); + service.schedule(() -> executionOrder.add("scheduled1"), 30, TimeUnit.MILLISECONDS); + service.submit(() -> executionOrder.add("immediate2")); + + // Process immediate tasks first + service.processPendingTasks(); + assertEquals(List.of("immediate1", "immediate2"), executionOrder); + + // Wait and process scheduled task + Thread.sleep(40); + service.processPendingTasks(); + assertEquals(List.of("immediate1", "immediate2", "scheduled1"), executionOrder); + } + + @Test + public void testCallableWithResult() throws Exception { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + + var future = service.submit(() -> "Hello World"); + + service.processPendingTasks(); + + assertTrue(future.isDone()); + assertEquals("Hello World", future.get()); + } + + @Test + public void testCallableWithException() { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + + var future = + service.submit( + () -> { + throw new RuntimeException("Test exception"); + }); + + service.processPendingTasks(); + + assertTrue(future.isDone()); + try { + future.get(); + fail("Expected exception"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("Test exception")); + } + } + + @Test + public void testScheduledCallable() throws Exception { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + + var future = service.schedule(() -> 42, 30, TimeUnit.MILLISECONDS); + + assertFalse(future.isDone()); + + Thread.sleep(40); + service.processPendingTasks(); + + assertTrue(future.isDone()); + assertEquals(Integer.valueOf(42), future.get()); + } + + @Test + public void testTasksExecuteOnOwnerThread() throws InterruptedException { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + long ownerThreadId = Thread.currentThread().threadId(); + + List immediateTaskThreadIds = new ArrayList<>(); + List scheduledTaskThreadIds = new ArrayList<>(); + + // Submit immediate tasks from different threads + service.submit(() -> immediateTaskThreadIds.add(Thread.currentThread().threadId())); + + Thread submitterThread = + new Thread( + () -> { + service.submit(() -> immediateTaskThreadIds.add(Thread.currentThread().threadId())); + service.schedule( + () -> scheduledTaskThreadIds.add(Thread.currentThread().threadId()), + 30, + TimeUnit.MILLISECONDS); + }); + submitterThread.start(); + submitterThread.join(); + + // Schedule a task from the owner thread + service.schedule( + () -> scheduledTaskThreadIds.add(Thread.currentThread().threadId()), + 50, + TimeUnit.MILLISECONDS); + + // Process immediate tasks on owner thread + service.processPendingTasks(); + + // All immediate tasks should have executed on owner thread + assertEquals(2, immediateTaskThreadIds.size()); + for (Long threadId : immediateTaskThreadIds) { + assertEquals(ownerThreadId, threadId.longValue()); + } + + // Wait for scheduled tasks + Thread.sleep(60); + service.processPendingTasks(); + + // All scheduled tasks should have executed on owner thread + assertEquals(2, scheduledTaskThreadIds.size()); + for (Long threadId : scheduledTaskThreadIds) { + assertEquals(ownerThreadId, threadId.longValue()); + } + } + + @Test + public void testHasPendingTasks() throws InterruptedException { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + + assertFalse(service.hasPendingTasks()); + + service.submit(() -> {}); + assertTrue(service.hasPendingTasks()); + + service.processPendingTasks(); + assertFalse(service.hasPendingTasks()); + + service.schedule(() -> {}, 30, TimeUnit.MILLISECONDS); + assertFalse(service.hasPendingTasks()); // Not ready yet + + Thread.sleep(40); + assertTrue(service.hasPendingTasks()); // Now ready + } + + @Test + public void testGetNextTaskDelay() throws InterruptedException { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + + assertEquals(-1, service.getNextTaskDelayNanos()); // No tasks + + service.submit(() -> {}); + assertEquals(0, service.getNextTaskDelayNanos()); // Immediate task + + service.processPendingTasks(); + assertEquals(-1, service.getNextTaskDelayNanos()); // No tasks again + + service.schedule(() -> {}, 100, TimeUnit.MILLISECONDS); + long delay = service.getNextTaskDelayNanos(); + assertTrue(delay > 0 && delay <= TimeUnit.MILLISECONDS.toNanos(100)); + } + + @Test + public void testShutdown() { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + + assertFalse(service.isShutdown()); + + service.shutdown(); + assertTrue(service.isShutdown()); + + // Should not accept new tasks after shutdown + try { + service.submit(() -> {}); + fail("Expected IllegalStateException"); + } catch (IllegalStateException e) { + // Expected + } + } + + @Test + public void testExceptionHandlingInTask() { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + AtomicInteger counter = new AtomicInteger(0); + + // Submit task that throws exception + service.submit( + () -> { + throw new RuntimeException("Test exception"); + }); + + // Submit another task to verify service continues + service.submit(() -> counter.incrementAndGet()); + + // Process tasks - should handle exception and continue + service.processPendingTasks(); + + // Second task should have executed despite first one throwing + assertEquals(1, counter.get()); + } + + @Test + public void testEventLoopPattern() throws InterruptedException { + YdocScheduledExecutorService service = new YdocScheduledExecutorService(); + AtomicInteger counter = new AtomicInteger(0); + + // Schedule recurring tasks + service.submit(() -> counter.incrementAndGet()); + service.schedule(() -> counter.incrementAndGet(), 10, TimeUnit.MILLISECONDS); + service.schedule(() -> counter.incrementAndGet(), 20, TimeUnit.MILLISECONDS); + + // Simulate event loop + for (int i = 0; i < 5; i++) { + service.processPendingTasks(); + Thread.sleep(10); + } + + // All tasks should have executed + assertEquals(3, counter.get()); + } +} diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/YdocTest.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/YdocTest.java index 2582e0758c75..dcbec23212f3 100644 --- a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/YdocTest.java +++ b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/YdocTest.java @@ -1,32 +1,22 @@ package org.enso.ydoc.server; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; import io.helidon.common.buffers.BufferData; import io.helidon.http.Status; import io.helidon.webclient.api.HttpClientResponse; import io.helidon.webclient.api.WebClient; import io.helidon.webclient.websocket.WsClient; -import io.helidon.webserver.WebServer; -import io.helidon.webserver.websocket.WsRouting; import io.helidon.websocket.WsListener; import io.helidon.websocket.WsSession; -import java.util.List; -import java.util.UUID; +import java.io.IOException; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; -import org.enso.ydoc.server.jsonrpc.JsonRpcRequest; -import org.enso.ydoc.server.jsonrpc.JsonRpcResponse; -import org.enso.ydoc.server.jsonrpc.model.ContentRoot; -import org.enso.ydoc.server.jsonrpc.model.FilePath; -import org.enso.ydoc.server.jsonrpc.model.FileSystemObject; -import org.enso.ydoc.server.jsonrpc.model.WriteCapability; -import org.enso.ydoc.server.jsonrpc.model.result.FileListResult; -import org.enso.ydoc.server.jsonrpc.model.result.InitProtocolConnectionResult; -import org.enso.ydoc.server.jsonrpc.model.result.TextOpenFileResult; +import org.enso.ydoc.api.YjsChannel; +import org.enso.ydoc.api.YjsChannelCallbacks; +import org.enso.ydoc.polyfill.web.WebEnvironment; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -44,87 +34,96 @@ public class YdocTest { private static final Logger log = LoggerFactory.getLogger(YdocTest.class); private Ydoc ydoc; - private ExecutorService webServerExecutor; - private WebServer ls; - - private static WebServer startWebSocketServer(ExecutorService executor) { - var routing = WsRouting.builder().endpoint("/", new LanguageServerConnection()); - var ws = - WebServer.builder().host("localhost").port(WEB_SERVER_PORT).addRouting(routing).build(); - - executor.submit(ws::start); - - return ws; - } + private ExecutorService executor; private static String ydocUrl(String doc) { - return YDOC_URL + doc + "?ls=" + WEB_SERVER_URL; + return YDOC_URL + doc + "?ls=" + WEB_SERVER_URL + "&data=" + WEB_SERVER_URL; } @Before - public void setup() { - webServerExecutor = Executors.newSingleThreadExecutor(); - ls = startWebSocketServer(webServerExecutor); - ydoc = Ydoc.builder().build(); + public void setup() throws Exception { + executor = + Executors.newSingleThreadExecutor( + (r) -> { + var t = new Thread(r); + t.setName("ydoc-thread"); + return t; + }); } @After public void tearDown() throws Exception { - ls.stop(); - webServerExecutor.shutdown(); - var stopped = webServerExecutor.awaitTermination(3, TimeUnit.SECONDS); - if (!stopped) { - var pending = webServerExecutor.shutdownNow(); - log.error("Executor pending [{}] tasks: [{}].", pending.size(), pending); + if (executor != null) { + executor.shutdownNow(); + executor.awaitTermination(3, TimeUnit.SECONDS); + } + if (ydoc != null) { + ydoc.close(); } - ydoc.close(); } @Test(timeout = 60000) - public void initialize() throws Exception { + public void start() throws Exception { var queue = new LinkedBlockingQueue(); + var jsonOnConnectLatch = new CountDownLatch(1); + var binaryOnConnectLatch = new CountDownLatch(1); + + YjsChannelCallbacks jsonCallbacks = + (YjsChannel channel) -> { + log.debug("Json onConnect called with channel: {}", channel); + jsonOnConnectLatch.countDown(); + }; + + YjsChannelCallbacks binaryCallbacks = + (YjsChannel channel) -> { + log.debug("Binary onConnect called with channel: {}", channel); + binaryOnConnectLatch.countDown(); + }; + + executor.submit( + () -> { + try { + var hostAccess = + WebEnvironment.defaultHostAccess + .allowImplementations(YjsChannel.class) + .allowPublicAccess(true); + ydoc = + Ydoc.builder() + .hostAccessBuilder(hostAccess) + .jsonChannelCallbacks(jsonCallbacks) + .binaryChannelCallbacks(binaryCallbacks) + .build(); + ydoc.start(); + } catch (IOException e) { + e.printStackTrace(); + Assert.fail("Ydoc.start()"); + } + }); - ydoc.start(); - + var connected = false; var ws = WsClient.builder().build(); - ws.connect(ydocUrl("index"), new DashboardConnection(queue)); - - var ok1 = queue.take(); - Assert.assertTrue(ok1.debugDataHex(), BufferDataUtil.isOk(ok1)); + while (!connected) { + try { + ws.connect(ydocUrl("index"), new DashboardConnection(queue)); + connected = true; + } catch (Exception _ignore) { + } + } - var buffer = queue.take(); - var uuid = BufferDataUtil.readUUID(buffer); - WsClient.builder().build().connect(ydocUrl(uuid.toString()), new DashboardConnection(queue)); + Assert.assertTrue("Client should be connected", connected); - var ok2 = queue.take(); - Assert.assertTrue(ok2.debugDataHex(), BufferDataUtil.isOk(ok2)); + var jsonOnConnectCalled = jsonOnConnectLatch.await(30, TimeUnit.SECONDS); + var binaryOnConnectCalled = binaryOnConnectLatch.await(30, TimeUnit.SECONDS); + Assert.assertTrue( + "Json onConnect callback should be called after client connects", jsonOnConnectCalled); + Assert.assertTrue( + "Binary onConnect callback should be called after client connects", binaryOnConnectCalled); WebClient http = WebClient.create(); HttpClientResponse healthcheckResponse = http.get(HEALTHCHECK_URL).request(); Assert.assertEquals(Status.OK_200, healthcheckResponse.status()); } - private static final class BufferDataUtil { - - private static final int UUID_BYTES = 36; - private static final int SUFFIX_BYTES = 3; - - private static boolean isOk(BufferData data) { - return data.readInt16() == 0; - } - - private static UUID readUUID(BufferData data) { - try { - data.skip(data.available() - UUID_BYTES - SUFFIX_BYTES); - var uuidString = data.readString(UUID_BYTES); - return UUID.fromString(uuidString); - } catch (Exception e) { - log.error("Failed to read UUID of\n{}", data.debugDataHex()); - throw e; - } - } - } - private static final class DashboardConnection implements WsListener { private static final Logger log = LoggerFactory.getLogger(DashboardConnection.class); @@ -147,77 +146,4 @@ public void onMessage(WsSession session, String text, boolean last) { log.error("Got unexpected message [{}].", text); } } - - private static final class LanguageServerConnection implements WsListener { - - private static final String METHOD_INIT_PROTOCOL_CONNECTION = "session/initProtocolConnection"; - private static final String METHOD_CAPABILITY_ACQUIRE = "capability/acquire"; - private static final String METHOD_FILE_LIST = "file/list"; - private static final String METHOD_TEXT_OPEN_FILE = "text/openFile"; - - private static final UUID PROJECT_ROOT_ID = new UUID(0, 1); - - private static final Logger log = LoggerFactory.getLogger(LanguageServerConnection.class); - - private static final ObjectMapper objectMapper = new ObjectMapper(); - - private LanguageServerConnection() {} - - @Override - public void onMessage(WsSession session, String text, boolean last) { - log.debug("Got message [{}].", text); - - try { - var request = objectMapper.readValue(text, JsonRpcRequest.class); - - JsonRpcResponse jsonRpcResponse = null; - - switch (request.method()) { - case METHOD_INIT_PROTOCOL_CONNECTION -> { - var contentRoots = - List.of( - new ContentRoot("Project", PROJECT_ROOT_ID), - new ContentRoot("Home", new UUID(0, 2)), - new ContentRoot("FileSystemRoot", new UUID(0, 3), "/")); - var initProtocolConnectionResult = - new InitProtocolConnectionResult("0.0.0-dev", "0.0.0-dev", contentRoots); - jsonRpcResponse = new JsonRpcResponse(request.id(), initProtocolConnectionResult); - } - case METHOD_CAPABILITY_ACQUIRE -> jsonRpcResponse = JsonRpcResponse.ok(request.id()); - case METHOD_FILE_LIST -> { - var paths = - List.of( - FileSystemObject.file( - "Main.enso", new FilePath(PROJECT_ROOT_ID, List.of("src")))); - var fileListResult = new FileListResult(paths); - jsonRpcResponse = new JsonRpcResponse(request.id(), fileListResult); - } - case METHOD_TEXT_OPEN_FILE -> { - var options = - new WriteCapability.Options( - new FilePath(PROJECT_ROOT_ID, List.of("src", "Main.enso"))); - var writeCapability = new WriteCapability("text/canEdit", options); - var textOpenFileResult = - new TextOpenFileResult( - writeCapability, - "main =", - "e5aeae8609bd90f94941d4227e6ec1e0f069d3318fb7bd93ffe4d391"); - jsonRpcResponse = new JsonRpcResponse(request.id(), textOpenFileResult); - } - } - - if (jsonRpcResponse != null) { - var response = objectMapper.writeValueAsString(jsonRpcResponse); - - log.debug("Sending [{}].", response); - session.send(response, true); - } else { - log.error("Unknown request."); - } - } catch (JsonProcessingException e) { - log.error("Failed to parse JSON.", e); - Assert.fail(e.getMessage()); - } - } - } } diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/JsonRpcRequest.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/JsonRpcRequest.java deleted file mode 100644 index 3400f06df5a1..000000000000 --- a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/JsonRpcRequest.java +++ /dev/null @@ -1,5 +0,0 @@ -package org.enso.ydoc.server.jsonrpc; - -import com.fasterxml.jackson.databind.JsonNode; - -public record JsonRpcRequest(String jsonrpc, String id, String method, JsonNode params) {} diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/JsonRpcResponse.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/JsonRpcResponse.java deleted file mode 100644 index 2fd67b711ffd..000000000000 --- a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/JsonRpcResponse.java +++ /dev/null @@ -1,16 +0,0 @@ -package org.enso.ydoc.server.jsonrpc; - -import org.enso.ydoc.server.jsonrpc.model.result.Result; - -public record JsonRpcResponse(String jsonrpc, String id, Result result) { - - private static final String JSONRPC_VERSION_2_0 = "2.0"; - - public JsonRpcResponse(String id, Result result) { - this(JSONRPC_VERSION_2_0, id, result); - } - - public static JsonRpcResponse ok(String id) { - return new JsonRpcResponse(id, null); - } -} diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/ContentRoot.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/ContentRoot.java deleted file mode 100644 index 7021f96c154e..000000000000 --- a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/ContentRoot.java +++ /dev/null @@ -1,10 +0,0 @@ -package org.enso.ydoc.server.jsonrpc.model; - -import java.util.UUID; - -public record ContentRoot(String type, UUID id, String path) { - - public ContentRoot(String type, UUID id) { - this(type, id, null); - } -} diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/FilePath.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/FilePath.java deleted file mode 100644 index 661ecd065392..000000000000 --- a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/FilePath.java +++ /dev/null @@ -1,6 +0,0 @@ -package org.enso.ydoc.server.jsonrpc.model; - -import java.util.List; -import java.util.UUID; - -public record FilePath(UUID rootId, List segments) {} diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/FileSystemObject.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/FileSystemObject.java deleted file mode 100644 index fd9174a4878b..000000000000 --- a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/FileSystemObject.java +++ /dev/null @@ -1,8 +0,0 @@ -package org.enso.ydoc.server.jsonrpc.model; - -public record FileSystemObject(String type, String name, FilePath path) { - - public static FileSystemObject file(String name, FilePath path) { - return new FileSystemObject("File", name, path); - } -} diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/WriteCapability.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/WriteCapability.java deleted file mode 100644 index 153ef52555ff..000000000000 --- a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/WriteCapability.java +++ /dev/null @@ -1,6 +0,0 @@ -package org.enso.ydoc.server.jsonrpc.model; - -public record WriteCapability(String method, Options registerOptions) { - - public record Options(FilePath path) {} -} diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/FileListResult.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/FileListResult.java deleted file mode 100644 index 7f27853fa613..000000000000 --- a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/FileListResult.java +++ /dev/null @@ -1,6 +0,0 @@ -package org.enso.ydoc.server.jsonrpc.model.result; - -import java.util.List; -import org.enso.ydoc.server.jsonrpc.model.FileSystemObject; - -public record FileListResult(List paths) implements Result {} diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/InitProtocolConnectionResult.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/InitProtocolConnectionResult.java deleted file mode 100644 index 226d21934007..000000000000 --- a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/InitProtocolConnectionResult.java +++ /dev/null @@ -1,7 +0,0 @@ -package org.enso.ydoc.server.jsonrpc.model.result; - -import java.util.List; -import org.enso.ydoc.server.jsonrpc.model.ContentRoot; - -public record InitProtocolConnectionResult( - String ensoVersion, String currentEdition, List contentRoots) implements Result {} diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/Result.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/Result.java deleted file mode 100644 index 16e0058b05d6..000000000000 --- a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/Result.java +++ /dev/null @@ -1,3 +0,0 @@ -package org.enso.ydoc.server.jsonrpc.model.result; - -public interface Result {} diff --git a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/TextOpenFileResult.java b/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/TextOpenFileResult.java deleted file mode 100644 index 1bbd4ff75244..000000000000 --- a/lib/java/ydoc-server/src/test/java/org/enso/ydoc/server/jsonrpc/model/result/TextOpenFileResult.java +++ /dev/null @@ -1,6 +0,0 @@ -package org.enso.ydoc.server.jsonrpc.model.result; - -import org.enso.ydoc.server.jsonrpc.model.WriteCapability; - -public record TextOpenFileResult( - WriteCapability writeCapability, String content, String currentVersion) implements Result {} diff --git a/lib/scala/json-rpc-server/src/main/java/module-info.java b/lib/scala/json-rpc-server/src/main/java/module-info.java index 6c39ae9969f0..6b48fdcf5345 100644 --- a/lib/scala/json-rpc-server/src/main/java/module-info.java +++ b/lib/scala/json-rpc-server/src/main/java/module-info.java @@ -2,6 +2,7 @@ requires scala.library; requires org.enso.scala.wrapper; requires org.enso.akka.wrapper; + requires org.enso.ydoc.api; requires org.slf4j; exports org.enso.jsonrpc; diff --git a/lib/scala/json-rpc-server/src/main/resources/META-INF/native-image/org/enso/jsonrpcserver/reflect-config.json b/lib/scala/json-rpc-server/src/main/resources/META-INF/native-image/org/enso/jsonrpcserver/reflect-config.json new file mode 100644 index 000000000000..50d7fa135e5c --- /dev/null +++ b/lib/scala/json-rpc-server/src/main/resources/META-INF/native-image/org/enso/jsonrpcserver/reflect-config.json @@ -0,0 +1,20 @@ +[ + { + "name": "org.enso.jsonrpc.YdocJsonRpcServer$OnMessageHandler", + "methods": [ + { + "name": "accept", + "parameterTypes": ["java.lang.Object"] + } + ] + }, + { + "name": "org.enso.jsonrpc.YdocJsonRpcServer$ServerCallbacks", + "methods": [ + { + "name": "onConnect", + "parameterTypes": ["org.enso.ydoc.api.YjsChannel"] + } + ] + } +] diff --git a/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/JsonRpcServer.scala b/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/JsonRpcServer.scala index 92a7e33128f3..5f720e5e0365 100644 --- a/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/JsonRpcServer.scala +++ b/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/JsonRpcServer.scala @@ -42,14 +42,13 @@ class JsonRpcServer( private val messageCallbackSinks = messageCallbacks.map(Sink.foreach[WebMessage]) - private def newUser(port: Int): Flow[Message, Message, NotUsed] = { + private def newUser: Flow[Message, Message, NotUsed] = { val messageHandler = system.actorOf( Props( new MessageHandlerSupervisor( clientControllerFactory, - protocolFactory, - port + protocolFactory ) ), s"message-handler-supervisor-${UUID.randomUUID()}" @@ -77,11 +76,11 @@ class JsonRpcServer( Sink.actorRef[MessageHandler.WebMessage]( messageHandler, { logger.trace("JSON sink stream finished with no failure") - MessageHandler.Disconnected(port) + MessageHandler.Disconnected() }, { e: Throwable => logger.trace("JSON sink stream finished with a failure", e) - MessageHandler.Disconnected(port) + MessageHandler.Disconnected() } ) ) @@ -95,7 +94,7 @@ class JsonRpcServer( OverflowStrategy.fail ) .mapMaterializedValue { outActor => - messageHandler ! MessageHandler.Connected(outActor, port) + messageHandler ! MessageHandler.Connected(outActor) NotUsed } .map((outMsg: MessageHandler.WebMessage) => TextMessage(outMsg.message)) @@ -109,7 +108,7 @@ class JsonRpcServer( override protected def serverRoute(port: Int): Route = { val webSocketEndpoint = path(config.path) { - get { handleWebSocketMessages(newUser(port)) } + get { handleWebSocketMessages(newUser) } } optionalEndpoints.foldLeft(webSocketEndpoint) { (chain, next) => @@ -152,6 +151,6 @@ object JsonRpcServer { ) } - case class WebConnect(webActor: ActorRef, port: Int) + case class WebConnect(webActor: ActorRef) } diff --git a/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/MessageHandler.scala b/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/MessageHandler.scala index e28950e012d2..3722dbf5894f 100644 --- a/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/MessageHandler.scala +++ b/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/MessageHandler.scala @@ -20,7 +20,7 @@ class MessageHandler(protocolFactory: ProtocolFactory, controller: ActorRef) * @return the actor behavior. */ override def receive: Receive = { - case MessageHandler.Connected(webConnection, _) => + case MessageHandler.Connected(webConnection) => unstashAll() context.become(established(webConnection, Map())) case _ => stash() @@ -38,8 +38,8 @@ class MessageHandler(protocolFactory: ProtocolFactory, controller: ActorRef) ): Receive = { case MessageHandler.WebMessage(msg) => handleWebMessage(msg, webConnection, awaitingResponses) - case MessageHandler.Disconnected(port) => - controller ! MessageHandler.Disconnected(port) + case MessageHandler.Disconnected() => + controller ! MessageHandler.Disconnected() context.stop(self) case request: Request[Method, Any] => issueRequest(request, webConnection, awaitingResponses) @@ -192,10 +192,10 @@ object MessageHandler { /** A control message used for [[MessageHandler]] initializations * @param webConnection the actor representing the web. */ - case class Connected(webConnection: ActorRef, port: Int) + case class Connected(webConnection: ActorRef) /** A control message used to notify the controller about * the connection being closed. */ - case class Disconnected(port: Int) + case class Disconnected() } diff --git a/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/MessageHandlerSupervisor.scala b/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/MessageHandlerSupervisor.scala index 7aba541da590..7c2a52125b40 100644 --- a/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/MessageHandlerSupervisor.scala +++ b/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/MessageHandlerSupervisor.scala @@ -20,8 +20,7 @@ import java.util.UUID */ final class MessageHandlerSupervisor( clientControllerFactory: ClientControllerFactory, - protocolFactory: ProtocolFactory, - port: Int + protocolFactory: ProtocolFactory ) extends Actor with LazyLogging with Stash { @@ -61,7 +60,7 @@ final class MessageHandlerSupervisor( s"message-handler-$clientId" ) context.watch(messageHandler) - clientActor ! JsonRpcServer.WebConnect(messageHandler, port) + clientActor ! JsonRpcServer.WebConnect(messageHandler) context.become(initialized(messageHandler)) unstashAll() diff --git a/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/YdocJsonRpcServer.scala b/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/YdocJsonRpcServer.scala new file mode 100644 index 000000000000..b86affec2fbf --- /dev/null +++ b/lib/scala/json-rpc-server/src/main/scala/org/enso/jsonrpc/YdocJsonRpcServer.scala @@ -0,0 +1,137 @@ +package org.enso.jsonrpc + +import akka.actor.{Actor, ActorRef, ActorSystem, Props} +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server.Route +import com.typesafe.scalalogging.LazyLogging +import org.enso.jsonrpc.MessageHandler +import org.enso.ydoc.api.YjsChannelCallbacks +import org.enso.ydoc.api.YjsChannel + +import java.util.UUID + +import scala.concurrent.ExecutionContext + +/** Exposes a multi-client JSON RPC Server instance over WebSocket connections. + * + * @param protocolFactory a protocol factory + * @param clientControllerFactory a factory used to create a client controller + * @param config a server config + * @param optionalEndpoints a list of optional endpoints + * @param system an actor system + */ +class YdocJsonRpcServer( + protocolFactory: ProtocolFactory, + clientControllerFactory: ClientControllerFactory, + config: JsonRpcServer.Config = JsonRpcServer.Config.default, + optionalEndpoints: List[Endpoint] = List.empty, + messageCallbacks: List[MessageHandler.WebMessage => Unit] = List.empty +)(implicit + val system: ActorSystem +) extends Server + with LazyLogging { + + implicit val ec: ExecutionContext = system.dispatcher + + val yjsChannelCallbacks = + new YdocJsonRpcServer.ServerCallbacks( + protocolFactory, + clientControllerFactory, + messageCallbacks, + system + ) + + override protected def serverRoute(port: Int): Route = { + val emptyEndpoint = + path("__null") { + post { null } + } + + optionalEndpoints.foldLeft(emptyEndpoint) { (chain, next) => + chain ~ next.route + } + } + + override protected def secureConfig(): Option[SecureConnectionConfig] = + config.secureConfig +} + +object YdocJsonRpcServer { + + final class ServerCallbacks( + protocolFactory: ProtocolFactory, + clientControllerFactory: ClientControllerFactory, + messageCallbacks: List[MessageHandler.WebMessage => Unit], + system: ActorSystem + ) extends YjsChannelCallbacks + with LazyLogging { + + override def onConnect(channel: YjsChannel): Unit = { + logger.info(s"ServerCallbacks.onConnect ${channel.getClass()}") + System.err.println( + s" is proxy ${java.lang.reflect.Proxy.isProxyClass(channel.getClass())})" + ) + System.err.println( + s" invocation ${java.lang.reflect.Proxy.getInvocationHandler(channel)})" + ) + + val incomingMessageHandler = + system.actorOf( + Props( + new MessageHandlerSupervisor( + clientControllerFactory, + protocolFactory + ) + ), + s"message-handler-supervisor-${UUID.randomUUID()}" + ) + try { + val toSubscribe = + new OnMessageHandler(messageCallbacks, incomingMessageHandler) + channel.subscribe(toSubscribe) + } catch { + case e: Exception => + logger.error("ServerCallbacks.onConnect err", e) + } + + val outgoingMessageHandler = + system.actorOf( + Props( + new OutgoingMessageHandler(channel) + ) + ) + incomingMessageHandler ! MessageHandler.Connected(outgoingMessageHandler) + } + } + + final class OnMessageHandler( + messageCallbacks: List[MessageHandler.WebMessage => Unit], + incomingMessageHandler: ActorRef + ) extends java.util.function.Consumer[Object] + with LazyLogging { + def accept(message: Object): Unit = { + message match { + case m: String => + logger.info(s"Received message $m") + val webMessage = MessageHandler.WebMessage(m) + incomingMessageHandler ! webMessage + messageCallbacks.foreach(cb => cb(webMessage)) + case _ => + logger.error("Received unsupported message:", message) + } + } + } + + final class OutgoingMessageHandler(channel: YjsChannel) + extends Actor + with LazyLogging { + + override def receive: Receive = { + case MessageHandler.WebMessage(message) => + logger.info(s"Sending message $message") + channel.send(message) + case unknown => + logger.error("Sending unsupported message:", unknown) + } + } +} diff --git a/lib/scala/json-rpc-server/src/test/scala/org/enso/jsonrpc/MessageHandlerSpec.scala b/lib/scala/json-rpc-server/src/test/scala/org/enso/jsonrpc/MessageHandlerSpec.scala index 08ee589638c1..70a3f62df7a2 100644 --- a/lib/scala/json-rpc-server/src/test/scala/org/enso/jsonrpc/MessageHandlerSpec.scala +++ b/lib/scala/json-rpc-server/src/test/scala/org/enso/jsonrpc/MessageHandlerSpec.scala @@ -96,7 +96,7 @@ class MessageHandlerSpec handler = system.actorOf( Props(new MessageHandler(MyProtocolFactory, controller.ref)) ) - handler ! Connected(out.ref, 0) + handler ! Connected(out.ref) } "Message handler" must { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index aa646c6bbffd..b97f809d1d5a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -546,6 +546,9 @@ importers: y-websocket: specifier: ^1.5.4 version: 1.5.4(yjs@13.6.21) + ydoc-channel: + specifier: workspace:* + version: link:../ydoc-channel ydoc-shared: specifier: workspace:* version: link:../ydoc-shared @@ -866,6 +869,25 @@ importers: specifier: 'catalog:' version: 3.2.4(@types/debug@4.1.12)(@types/node@24.2.1)(@vitest/browser@3.2.4)(jiti@1.21.7)(jsdom@26.1.0)(terser@5.37.0)(yaml@2.7.0) + app/ydoc-channel: + dependencies: + lib0: + specifier: ^0.2.99 + version: 0.2.99 + yjs: + specifier: ^13.6.19 + version: 13.6.21 + devDependencies: + '@types/node': + specifier: 'catalog:' + version: 24.2.1 + typescript: + specifier: 'catalog:' + version: 5.7.2 + vitest: + specifier: 'catalog:' + version: 3.2.4(@types/debug@4.1.12)(@types/node@24.2.1)(@vitest/browser@3.2.4)(jiti@1.21.7)(jsdom@26.1.0)(terser@5.37.0)(yaml@2.7.0) + app/ydoc-server: dependencies: debug: @@ -886,6 +908,9 @@ importers: y-protocols: specifier: ^1.0.6 version: 1.0.6(yjs@13.6.21) + ydoc-channel: + specifier: workspace:* + version: link:../ydoc-channel ydoc-shared: specifier: workspace:* version: link:../ydoc-shared @@ -920,12 +945,12 @@ importers: app/ydoc-server-polyglot: dependencies: + ydoc-channel: + specifier: workspace:* + version: link:../ydoc-channel ydoc-server: specifier: workspace:* version: link:../ydoc-server - ydoc-shared: - specifier: workspace:* - version: link:../ydoc-shared devDependencies: '@fal-works/esbuild-plugin-global-externals': specifier: ^2.1.2 @@ -987,6 +1012,9 @@ importers: rust-ffi: specifier: workspace:* version: link:../rust-ffi + ydoc-channel: + specifier: workspace:* + version: link:../ydoc-channel yjs: specifier: ^13.6.21 version: 13.6.21 diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index e8a150b4bdb2..04406f3d46fb 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -7,8 +7,8 @@ packages: - app/project-manager-shim - app/rust-ffi - app/table-expression + - app/ydoc-channel - app/ydoc-server - - app/ydoc-server-nodejs - app/ydoc-server-polyglot - app/ydoc-shared - lib/js/runner