diff --git a/core/src/commonMain/kotlin/kotlinx/rpc/RpcClient.kt b/core/src/commonMain/kotlin/kotlinx/rpc/RpcClient.kt index 5e6073320..e33c985b0 100644 --- a/core/src/commonMain/kotlin/kotlinx/rpc/RpcClient.kt +++ b/core/src/commonMain/kotlin/kotlinx/rpc/RpcClient.kt @@ -28,22 +28,6 @@ public interface RpcClient : CoroutineScope { */ public suspend fun call(call: RpcCall): T - /** - * This method is used by generated clients to perform a call to the server. - * - * @param T type of the result - * @param serviceScope service's coroutine scope - * @param call an object that contains all required information about the called method, - * that is needed to route it properly to the server. - * @return actual result of the call, for example, data from the server - */ - @Deprecated( - "This method was primarily used for fields in RPC services, which are now deprecated. " + - "See https://kotlin.github.io/kotlinx-rpc/strict-mode.html fields guide for more information", - level = DeprecationLevel.ERROR, - ) - public fun callAsync(serviceScope: CoroutineScope, call: RpcCall): Deferred - /** * This method is used by generated clients to perform a call to the server * that returns a streaming flow. @@ -53,9 +37,7 @@ public interface RpcClient : CoroutineScope { * that is needed to route it properly to the server. * @return the actual result of the call, for example, data from the server */ - public fun callServerStreaming(call: RpcCall): Flow { - error("Non-suspending server streaming is not supported by this client") - } + public fun callServerStreaming(call: RpcCall): Flow /** * Provides child [CoroutineContext] for a new [RemoteService] service stub. diff --git a/krpc/krpc-client/build.gradle.kts b/krpc/krpc-client/build.gradle.kts index 8b91ed2e6..195b9dabd 100644 --- a/krpc/krpc-client/build.gradle.kts +++ b/krpc/krpc-client/build.gradle.kts @@ -28,14 +28,3 @@ kotlin { } } -rpc { - strict { - stateFlow = RpcStrictMode.NONE - sharedFlow = RpcStrictMode.NONE - nestedFlow = RpcStrictMode.NONE - streamScopedFunctions = RpcStrictMode.NONE - suspendingServerStreaming = RpcStrictMode.NONE - notTopLevelServerFlow = RpcStrictMode.NONE - fields = RpcStrictMode.NONE - } -} diff --git a/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/KrpcClient.kt b/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/KrpcClient.kt index 94f1c630b..dff4a949a 100644 --- a/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/KrpcClient.kt +++ b/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/KrpcClient.kt @@ -8,25 +8,35 @@ import kotlinx.atomicfu.atomic import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.FlowCollector +import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.rpc.RpcCall import kotlinx.rpc.RpcClient import kotlinx.rpc.annotations.Rpc import kotlinx.rpc.descriptor.RpcCallable -import kotlinx.rpc.internal.serviceScopeOrNull +import kotlinx.rpc.descriptor.RpcInvokator import kotlinx.rpc.internal.utils.InternalRpcApi -import kotlinx.rpc.internal.utils.RpcInternalSupervisedCompletableDeferred import kotlinx.rpc.internal.utils.getOrNull import kotlinx.rpc.internal.utils.map.RpcInternalConcurrentHashMap import kotlinx.rpc.krpc.* +import kotlinx.rpc.krpc.client.internal.ClientStreamContext +import kotlinx.rpc.krpc.client.internal.ClientStreamSerializer import kotlinx.rpc.krpc.client.internal.KrpcClientConnector +import kotlinx.rpc.krpc.client.internal.StreamCall import kotlinx.rpc.krpc.internal.* import kotlinx.rpc.krpc.internal.logging.RpcInternalCommonLogger import kotlinx.serialization.BinaryFormat +import kotlinx.serialization.KSerializer import kotlinx.serialization.SerialFormat import kotlinx.serialization.StringFormat +import kotlinx.serialization.modules.SerializersModule +import kotlin.collections.first import kotlin.coroutines.CoroutineContext import kotlin.coroutines.cancellation.CancellationException +import kotlin.time.Duration.Companion.seconds @Deprecated("Use KrpcClient instead", ReplaceWith("KrpcClient"), level = DeprecationLevel.ERROR) public typealias KRPCClient = KrpcClient @@ -50,9 +60,9 @@ public typealias KRPCClient = KrpcClient */ @OptIn(InternalCoroutinesApi::class) public abstract class KrpcClient( - final override val config: KrpcConfig.Client, + private val config: KrpcConfig.Client, transport: KrpcTransport, -) : KrpcServiceHandler(), RpcClient, KrpcEndpoint { +) : RpcClient, KrpcEndpoint { // we make a child here, so we can send cancellation messages before closing the connection final override val coroutineContext: CoroutineContext = SupervisorJob(transport.coroutineContext.job) @@ -68,7 +78,7 @@ public abstract class KrpcClient( private val callCounter = atomic(0L) - final override val logger: RpcInternalCommonLogger = RpcInternalCommonLogger.logger(rpcInternalObjectId()) + private val logger: RpcInternalCommonLogger = RpcInternalCommonLogger.logger(rpcInternalObjectId()) private val serverSupportedPlugins: CompletableDeferred> = CompletableDeferred() @@ -96,7 +106,7 @@ public abstract class KrpcClient( @OptIn(InternalCoroutinesApi::class) @InternalRpcApi final override fun provideStubContext(serviceId: Long): CoroutineContext { - val childContext = SupervisorJob(coroutineContext.job).withClientStreamScope() + val childContext = SupervisorJob(coroutineContext.job) childContext.job.invokeOnCompletion(onCancelling = true) { if (!clientCancelled) { @@ -144,175 +154,10 @@ public abstract class KrpcClient( } } - @Deprecated( - "This method was primarily used for fields in RPC services, which are now deprecated. " + - "See https://kotlin.github.io/kotlinx-rpc/strict-mode.html fields guide for more information" - ) - override fun callAsync( - serviceScope: CoroutineScope, - call: RpcCall, - ): Deferred { - val callable = call.descriptor.getCallable(call.callableName) - ?: error("Unexpected callable '${call.callableName}' for ${call.descriptor.fqName} service") - - val deferred = RpcInternalSupervisedCompletableDeferred(serviceScope.coroutineContext.job) - - /** - * Launched on the service scope (receiver) - * Moreover, this scope has [StreamScope] that is used to handle field streams. - * [StreamScope] is provided to a service via [provideStubContext]. - */ - serviceScope.launch { - val rpcCall = call(call, callable, deferred) - - deferred.invokeOnCompletion { cause -> - if (cause == null) { - rpcCall.streamContext.valueOrNull?.launchIf({ incomingHotFlowsAvailable }) { - handleIncomingHotFlows(it) - } - } - } - } - - return deferred - } - final override suspend fun call(call: RpcCall): T { - val callable = call.descriptor.getCallable(call.callableName) - ?: error("Unexpected callable '${call.callableName}' for ${call.descriptor.fqName} service") - - val callCompletableResult = RpcInternalSupervisedCompletableDeferred() - val rpcCall = call(call, callable, callCompletableResult) - val result = callCompletableResult.await() - - // incomingHotFlowsAvailable value is known after await - rpcCall.streamContext.valueOrNull?.launchIf({ incomingHotFlowsAvailable }) { - handleIncomingHotFlows(it) - } - - return result - } - - private suspend fun call( - call: RpcCall, - callable: RpcCallable<*>, - callResult: CompletableDeferred, - ): RpcCallStreamContextFormatAndId { - val wrappedCallResult = RequestCompletableDeferred(callResult) - val rpcCall = prepareAndExecuteCall(call, callable, wrappedCallResult) - - rpcCall.streamContext.valueOrNull?.launchIf({ outgoingStreamsAvailable }) { - handleOutgoingStreams(it, rpcCall.serialFormat, call.descriptor.fqName) - } - - val handle = serviceScopeOrNull()?.run { - serviceCoroutineScope.coroutineContext.job.invokeOnCompletion(onCancelling = true) { cause -> - // service can only be canceled, it can't complete successfully - callResult.completeExceptionally(CancellationException(cause)) - - rpcCall.streamContext.valueOrNull?.cancel("Service cancelled", cause) - } - } - - callResult.invokeOnCompletion { cause -> - if (cause != null) { - cancellingRequests[rpcCall.callId] = call.descriptor.fqName - - rpcCall.streamContext.valueOrNull?.cancel("Request failed", cause) - - if (!wrappedCallResult.callExceptionOccurred) { - sendCancellation(CancellationType.REQUEST, call.serviceId.toString(), rpcCall.callId) - } - - handle?.dispose() - } else { - val streamScope = rpcCall.streamContext.valueOrNull?.streamScope - - if (streamScope == null) { - handle?.dispose() - - connector.unsubscribeFromMessages(call.descriptor.fqName, rpcCall.callId) - } - - streamScope?.onScopeCompletion(rpcCall.callId) { - handle?.dispose() - - cancellingRequests[rpcCall.callId] = call.descriptor.fqName - - sendCancellation(CancellationType.REQUEST, call.serviceId.toString(), rpcCall.callId) - } - } - } - - return rpcCall + return callServerStreaming(call).first() } - private suspend fun prepareAndExecuteCall( - call: RpcCall, - callable: RpcCallable<*>, - callResult: RequestCompletableDeferred<*>, - ): RpcCallStreamContextFormatAndId { - // we should wait for the handshake to finish - awaitHandshakeCompletion() - - val id = callCounter.incrementAndGet() - - val callId = "$connectionId:${callable.name}:$id" - - logger.trace { "start a call[$callId] ${callable.name}" } - - val fallbackScope = serviceScopeOrNull() - ?.serviceCoroutineScope - ?.let { streamScopeOrNull(it) } - - val streamContext = LazyKrpcStreamContext(streamScopeOrNull(), fallbackScope) { - KrpcStreamContext(callId, config, connectionId, call.serviceId, it) - } - val serialFormat = prepareSerialFormat(streamContext) - val firstMessage = serializeRequest(callId, call, callable, serialFormat) - - @Suppress("UNCHECKED_CAST") - executeCall( - callId = callId, - streamContext = streamContext, - call = call, - callable = callable, - firstMessage = firstMessage, - serialFormat = serialFormat, - callResult = callResult as RequestCompletableDeferred - ) - - return RpcCallStreamContextFormatAndId(streamContext, serialFormat, callId) - } - - private data class RpcCallStreamContextFormatAndId( - val streamContext: LazyKrpcStreamContext, - val serialFormat: SerialFormat, - val callId: String, - ) - - private suspend fun executeCall( - callId: String, - streamContext: LazyKrpcStreamContext, - call: RpcCall, - callable: RpcCallable<*>, - firstMessage: KrpcCallMessage, - serialFormat: SerialFormat, - callResult: RequestCompletableDeferred, - ) { - connector.subscribeToCallResponse(call.descriptor.fqName, callId) { message -> - if (cancellingRequests.containsKey(callId)) { - return@subscribeToCallResponse - } - - handleMessage(message, streamContext, callable, serialFormat, callResult) - } - - connector.sendMessage(firstMessage) - } - - private val noFlowSerialFormat = config.serialFormatInitializer.build() - @Suppress("detekt.CyclomaticComplexMethod") override fun callServerStreaming(call: RpcCall): Flow { return flow { @@ -326,15 +171,7 @@ public abstract class KrpcClient( val channel = Channel() - val streamScope = StreamScope(currentCoroutineContext()) - try { - val streamContext = LazyKrpcStreamContext(streamScope, null) { - KrpcStreamContext(callId, config, connectionId, call.serviceId, it) - } - - val serialFormat = prepareSerialFormat(streamContext) - val request = serializeRequest( callId = callId, call = call, @@ -349,19 +186,23 @@ public abstract class KrpcClient( handleServerStreamingMessage(message, channel, callable, call, callId) } - streamContext.valueOrNull?.launchIf({ outgoingStreamsAvailable }) { - handleOutgoingStreams(it, serialFormat, call.descriptor.fqName) - } - - while (true) { - val element = channel.receiveCatching() - if (element.isClosed) { - val ex = element.exceptionOrNull() ?: break - throw ex + coroutineScope { + val clientStreamsJob = launch(CoroutineName("client-stream-root-${call.serviceId}-$callId")) { + supervisorScope { + clientStreamContext.streams[callId].orEmpty().forEach { + launch(CoroutineName("client-stream-${call.serviceId}-$callId-${it.streamId}")) { + handleOutgoingStream(it, serialFormat, call.descriptor.fqName) + } + } + } + println("finished client streams job") } - if (!element.isFailure) { - emit(element.getOrThrow()) + try { + consumeAndEmitServerMessages(channel) + } finally { + clientStreamsJob.cancelAndJoin() + clientStreamContext.streams.remove(callId) } } } catch (e: CancellationException) { @@ -371,12 +212,25 @@ public abstract class KrpcClient( throw e } finally { - streamScope.close() channel.close() } } } + private suspend fun FlowCollector.consumeAndEmitServerMessages(channel: Channel) { + while (true) { + val element = channel.receiveCatching() + if (element.isClosed) { + val ex = element.exceptionOrNull() ?: break + throw ex + } + + if (!element.isFailure) { + emit(element.getOrThrow()) + } + } + } + private suspend fun handleServerStreamingMessage( message: KrpcCallMessage, channel: Channel, @@ -405,10 +259,10 @@ public abstract class KrpcClient( is KrpcCallMessage.CallSuccess, is KrpcCallMessage.StreamMessage -> { val value = runCatching { - val serializerResult = noFlowSerialFormat.serializersModule + val serializerResult = serialFormat.serializersModule .rpcSerializerForType(callable.returnType) - decodeMessageData(noFlowSerialFormat, serializerResult, message) + decodeMessageData(serialFormat, serializerResult, message) } @Suppress("UNCHECKED_CAST") @@ -428,57 +282,6 @@ public abstract class KrpcClient( } } - private suspend fun handleMessage( - message: KrpcCallMessage, - streamContext: LazyKrpcStreamContext, - callable: RpcCallable<*>, - serialFormat: SerialFormat, - callResult: RequestCompletableDeferred, - ) { - when (message) { - is KrpcCallMessage.CallData -> { - error("Unexpected message") - } - - is KrpcCallMessage.CallException -> { - val cause = runCatching { - message.cause.deserialize() - } - - val result = if (cause.isFailure) { - cause.exceptionOrNull()!! - } else { - cause.getOrNull()!! - } - - callResult.callExceptionOccurred = true - callResult.completeExceptionally(result) - } - - is KrpcCallMessage.CallSuccess -> { - val value = runCatching { - val serializerResult = serialFormat.serializersModule.rpcSerializerForType(callable.returnType) - - decodeMessageData(serialFormat, serializerResult, message) - } - - callResult.completeWith(value) - } - - is KrpcCallMessage.StreamCancel -> { - streamContext.awaitInitialized().cancelStream(message) - } - - is KrpcCallMessage.StreamFinished -> { - streamContext.awaitInitialized().closeStream(message) - } - - is KrpcCallMessage.StreamMessage -> { - streamContext.awaitInitialized().send(message, serialFormat) - } - } - } - @InternalRpcApi final override suspend fun handleCancellation(message: KrpcGenericMessage) { when (val type = message.cancellationType()) { @@ -513,7 +316,10 @@ public abstract class KrpcClient( val serializerData = serialFormat.serializersModule.rpcSerializerForType(callable.dataType) return when (serialFormat) { is StringFormat -> { - val stringValue = serialFormat.encodeToString(serializerData, call.data) + val stringValue = clientStreamContext.scoped(callId, call.serviceId) { + serialFormat.encodeToString(serializerData, call.data) + } + KrpcCallMessage.CallDataString( callId = callId, serviceType = call.descriptor.fqName, @@ -527,7 +333,10 @@ public abstract class KrpcClient( } is BinaryFormat -> { - val binaryValue = serialFormat.encodeToByteArray(serializerData, call.data) + val binaryValue = clientStreamContext.scoped(callId, call.serviceId) { + serialFormat.encodeToByteArray(serializerData, call.data) + } + KrpcCallMessage.CallDataBinary( callId = callId, serviceType = call.descriptor.fqName, @@ -545,8 +354,106 @@ public abstract class KrpcClient( } } } -} -private class RequestCompletableDeferred(delegate: CompletableDeferred) : CompletableDeferred by delegate { - var callExceptionOccurred: Boolean = false + private suspend fun handleOutgoingStream( + outgoingStream: StreamCall, + serialFormat: SerialFormat, + serviceTypeString: String, + ) { + try { + collectAndSendOutgoingStream( + serialFormat = serialFormat, + flow = outgoingStream.stream, + outgoingStream = outgoingStream, + serviceTypeString = serviceTypeString, + ) + } catch (e: CancellationException) { + throw e + } catch (@Suppress("detekt.TooGenericExceptionCaught") cause: Throwable) { + val serializedReason = serializeException(cause) + val message = KrpcCallMessage.StreamCancel( + callId = outgoingStream.callId, + serviceType = serviceTypeString, + streamId = outgoingStream.streamId, + cause = serializedReason, + connectionId = outgoingStream.connectionId, + serviceId = outgoingStream.serviceId, + ) + sender.sendMessage(message) + + throw cause + } + + val message = KrpcCallMessage.StreamFinished( + callId = outgoingStream.callId, + serviceType = serviceTypeString, + streamId = outgoingStream.streamId, + connectionId = outgoingStream.connectionId, + serviceId = outgoingStream.serviceId, + ) + + sender.sendMessage(message) + } + + private suspend fun collectAndSendOutgoingStream( + serialFormat: SerialFormat, + flow: Flow<*>, + serviceTypeString: String, + outgoingStream: StreamCall, + ) { + flow.collect { + println("collected: $it, ${outgoingStream.streamId}") + val message = when (serialFormat) { + is StringFormat -> { + val stringData = serialFormat.encodeToString(outgoingStream.elementSerializer, it) + KrpcCallMessage.StreamMessageString( + callId = outgoingStream.callId, + serviceType = serviceTypeString, + streamId = outgoingStream.streamId, + data = stringData, + connectionId = outgoingStream.connectionId, + serviceId = outgoingStream.serviceId, + ) + } + + is BinaryFormat -> { + val binaryData = serialFormat.encodeToByteArray(outgoingStream.elementSerializer, it) + KrpcCallMessage.StreamMessageBinary( + callId = outgoingStream.callId, + serviceType = serviceTypeString, + streamId = outgoingStream.streamId, + data = binaryData, + connectionId = outgoingStream.connectionId, + serviceId = outgoingStream.serviceId, + ) + } + + else -> { + unsupportedSerialFormatError(serialFormat) + } + } + + sender.sendMessage(message) + } + } + + private val clientStreamContext: ClientStreamContext = ClientStreamContext(connectionId = connectionId) + + private val serialFormat: SerialFormat by lazy { + val module = SerializersModule { + contextual(Flow::class) { + @Suppress("UNCHECKED_CAST") + ClientStreamSerializer(clientStreamContext, it.first() as KSerializer) + } + } + + config.serialFormatInitializer.applySerializersModuleAndBuild(module) + } + + private fun RpcCallable<*>.toMessageCallType(): KrpcCallMessage.CallType { + return when (invokator) { + is RpcInvokator.Method -> KrpcCallMessage.CallType.Method + is RpcInvokator.Field -> KrpcCallMessage.CallType.Field + } + } } diff --git a/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/internal/ClientStreamContext.kt b/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/internal/ClientStreamContext.kt new file mode 100644 index 000000000..2ca48eea8 --- /dev/null +++ b/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/internal/ClientStreamContext.kt @@ -0,0 +1,59 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.krpc.client.internal + +import kotlinx.atomicfu.atomic +import kotlinx.coroutines.flow.Flow +import kotlinx.rpc.internal.utils.map.RpcInternalConcurrentHashMap +import kotlinx.serialization.KSerializer +import kotlin.native.concurrent.ThreadLocal + +internal class ClientStreamContext(private val connectionId: Long?) { + val streams = RpcInternalConcurrentHashMap>() + + @ThreadLocal + private var currentCallId: String? = null + + @ThreadLocal + private var currentServiceId: Long? = null + + fun scoped(callId: String, serviceId: Long, body: () -> T): T { + try { + currentCallId = callId + currentServiceId = serviceId + return body() + } finally { + currentCallId = null + currentServiceId = null + } + } + + private val streamIdCounter = atomic(0L) + + fun registerClientStream(value: Flow<*>, elementKind: KSerializer<*>): String { + val callId = currentCallId ?: error("No call id") + val serviceId = currentServiceId ?: error("No service id") + val streamId = "$STREAM_ID_PREFIX${streamIdCounter.getAndIncrement()}" + + @Suppress("UNCHECKED_CAST") + val stream = StreamCall( + callId = callId, + streamId = streamId, + stream = value, + elementSerializer = elementKind as KSerializer, + connectionId = connectionId, + serviceId = serviceId + ) + + @Suppress("UNCHECKED_CAST") + streams.merge(callId, listOf(stream)) { old, new -> old + new } + + return streamId + } + + private companion object { + private const val STREAM_ID_PREFIX = "stream:" + } +} diff --git a/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/internal/ClientStreamSerializer.kt b/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/internal/ClientStreamSerializer.kt new file mode 100644 index 000000000..d677dbf7f --- /dev/null +++ b/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/internal/ClientStreamSerializer.kt @@ -0,0 +1,33 @@ +/* + * Copyright 2023-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.krpc.client.internal + +import kotlinx.coroutines.flow.Flow +import kotlinx.rpc.krpc.internal.StreamSerializer +import kotlinx.serialization.KSerializer +import kotlinx.serialization.descriptors.* +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder + +internal class ClientStreamSerializer( + val context: ClientStreamContext, + val elementType: KSerializer, +) : KSerializer>, StreamSerializer() { + override val descriptor: SerialDescriptor by lazy { + buildClassSerialDescriptor("ClientStreamSerializer") { + element(STREAM_ID_SERIAL_NAME, streamIdDescriptor) + } + } + + override fun deserialize(decoder: Decoder): Flow<*> { + error("This method must not be called. Please report to the developer.") + } + + override fun serialize(encoder: Encoder, value: Flow<*>) { + val id = context.registerClientStream(value, elementType) + + encoder.encodeString(id) + } +} diff --git a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/KrpcStreamCall.kt b/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/internal/StreamCall.kt similarity index 56% rename from krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/KrpcStreamCall.kt rename to krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/internal/StreamCall.kt index 6d22b4f25..4c4a556c5 100644 --- a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/KrpcStreamCall.kt +++ b/krpc/krpc-client/src/commonMain/kotlin/kotlinx/rpc/krpc/client/internal/StreamCall.kt @@ -1,16 +1,16 @@ /* - * Copyright 2023-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. */ -package kotlinx.rpc.krpc.internal +package kotlinx.rpc.krpc.client.internal +import kotlinx.coroutines.flow.Flow import kotlinx.serialization.KSerializer -internal data class KrpcStreamCall( +internal data class StreamCall( val callId: String, val streamId: String, - val stream: Any, - val kind: StreamKind, + val stream: Flow<*>, val elementSerializer: KSerializer, val connectionId: Long?, val serviceId: Long?, diff --git a/krpc/krpc-core/build.gradle.kts b/krpc/krpc-core/build.gradle.kts index 29040a7dc..dd7ff2a67 100644 --- a/krpc/krpc-core/build.gradle.kts +++ b/krpc/krpc-core/build.gradle.kts @@ -28,15 +28,3 @@ kotlin { } } } - -rpc { - strict { - stateFlow = RpcStrictMode.NONE - sharedFlow = RpcStrictMode.NONE - nestedFlow = RpcStrictMode.NONE - streamScopedFunctions = RpcStrictMode.NONE - suspendingServerStreaming = RpcStrictMode.NONE - notTopLevelServerFlow = RpcStrictMode.NONE - fields = RpcStrictMode.NONE - } -} diff --git a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/KrpcConfig.kt b/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/KrpcConfig.kt index 962e965e1..db789b1c8 100644 --- a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/KrpcConfig.kt +++ b/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/KrpcConfig.kt @@ -4,9 +4,6 @@ package kotlinx.rpc.krpc -import kotlinx.coroutines.channels.BufferOverflow -import kotlinx.coroutines.flow.MutableSharedFlow -import kotlinx.coroutines.flow.SharedFlow import kotlinx.rpc.internal.utils.InternalRpcApi import kotlinx.rpc.krpc.serialization.KrpcSerialFormat import kotlinx.rpc.krpc.serialization.KrpcSerialFormatBuilder @@ -19,88 +16,6 @@ public typealias RPCConfigBuilder = KrpcConfigBuilder * Builder for [KrpcConfig]. Provides DSL to configure parameters for KrpcClient and/or KrpcServer. */ public sealed class KrpcConfigBuilder private constructor() { - /** - * DSL for parameters of [MutableSharedFlow] and [SharedFlow]. - * - * This is a temporary solution that hides the problem of transferring these parameters. - * [SharedFlow] and [MutableSharedFlow] do not define theirs 'replay', 'extraBufferCapacity' and 'onBufferOverflow' - * parameters, and thus they cannot be encoded and transferred. - * So then creating their instance on an endpoint, the library should know which parameters to use. - */ - @Deprecated( - "SharedFlow support is deprecated, see https://kotlin.github.io/kotlinx-rpc/0-5-0.html", - level = DeprecationLevel.WARNING, - ) - @Suppress("MemberVisibilityCanBePrivate") - public class SharedFlowParametersBuilder internal constructor() { - /** - * The number of values replayed to new subscribers (cannot be negative, defaults to zero). - */ - @Deprecated( - "SharedFlow support is deprecated, see https://kotlin.github.io/kotlinx-rpc/0-5-0.html", - level = DeprecationLevel.WARNING, - ) - public var replay: Int = DEFAULT_REPLAY - - /** - * The number of values buffered in addition to replay. - * emit does not suspend while there is a buffer space remaining - * (optional, cannot be negative, defaults to zero). - */ - @Deprecated( - "SharedFlow support is deprecated, see https://kotlin.github.io/kotlinx-rpc/0-5-0.html", - level = DeprecationLevel.WARNING, - ) - public var extraBufferCapacity: Int = DEFAULT_EXTRA_BUFFER_CAPACITY - - /** - * Configures an emit action on buffer overflow. - * Optional, defaults to suspending attempts to emit a value. - * Values other than [BufferOverflow.SUSPEND] are supported only when replay > 0 or extraBufferCapacity > 0. - * Buffer overflow can happen only when there is at least one subscriber - * that is not ready to accept the new value. - * In the absence of subscribers only the most recent replay values are stored - * and the buffer overflow behavior is never triggered and has no effect. - */ - @Deprecated( - "SharedFlow support is deprecated, see https://kotlin.github.io/kotlinx-rpc/0-5-0.html", - level = DeprecationLevel.WARNING, - ) - public var onBufferOverflow: BufferOverflow = BufferOverflow.SUSPEND - - @InternalRpcApi - public fun builder(): () -> MutableSharedFlow = { - @Suppress("DEPRECATION") - MutableSharedFlow(replay, extraBufferCapacity, onBufferOverflow) - } - - private companion object { - /** - * Default value of [replay] - */ - const val DEFAULT_REPLAY = 1 - /** - * Default value of [extraBufferCapacity] - */ - const val DEFAULT_EXTRA_BUFFER_CAPACITY = 10 - } - } - - @Suppress("DEPRECATION") - protected var sharedFlowBuilder: () -> MutableSharedFlow = SharedFlowParametersBuilder().builder() - - /** - * @see SharedFlowParametersBuilder - */ - @Deprecated( - "SharedFlow support is deprecated, see https://kotlin.github.io/kotlinx-rpc/0-5-0.html", - level = DeprecationLevel.WARNING, - ) - public fun sharedFlowParameters(builder: @Suppress("DEPRECATION") SharedFlowParametersBuilder.() -> Unit) { - @Suppress("DEPRECATION") - sharedFlowBuilder = SharedFlowParametersBuilder().apply(builder).builder() - } - private var serialFormatInitializer: KrpcSerialFormatBuilder<*, *>? = null private val configuration = object : KrpcSerialFormatConfiguration { @@ -153,7 +68,6 @@ public sealed class KrpcConfigBuilder private constructor() { public class Client : KrpcConfigBuilder() { public fun build(): KrpcConfig.Client { return KrpcConfig.Client( - sharedFlowBuilder = sharedFlowBuilder, serialFormatInitializer = rpcSerialFormat(), waitForServices = waitForServices, ) @@ -166,7 +80,6 @@ public sealed class KrpcConfigBuilder private constructor() { public class Server : KrpcConfigBuilder() { public fun build(): KrpcConfig.Server { return KrpcConfig.Server( - sharedFlowBuilder = sharedFlowBuilder, serialFormatInitializer = rpcSerialFormat(), waitForServices = waitForServices, ) @@ -181,8 +94,6 @@ public typealias RPCConfig = KrpcConfig * Configuration class that is used by kRPC protocol's client and server (KrpcClient and KrpcServer). */ public sealed interface KrpcConfig { - @InternalRpcApi - public val sharedFlowBuilder: () -> MutableSharedFlow @InternalRpcApi public val serialFormatInitializer: KrpcSerialFormatBuilder<*, *> @InternalRpcApi @@ -192,7 +103,6 @@ public sealed interface KrpcConfig { * @see [KrpcConfig] */ public class Client internal constructor( - override val sharedFlowBuilder: () -> MutableSharedFlow, override val serialFormatInitializer: KrpcSerialFormatBuilder<*, *>, override val waitForServices: Boolean, ) : KrpcConfig @@ -201,7 +111,6 @@ public sealed interface KrpcConfig { * @see [KrpcConfig] */ public class Server internal constructor( - override val sharedFlowBuilder: () -> MutableSharedFlow, override val serialFormatInitializer: KrpcSerialFormatBuilder<*, *>, override val waitForServices: Boolean, ) : KrpcConfig diff --git a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/StreamScope.kt b/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/StreamScope.kt deleted file mode 100644 index 90b75b662..000000000 --- a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/StreamScope.kt +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. - */ - -package kotlinx.rpc.krpc - -import kotlinx.coroutines.* -import kotlinx.rpc.internal.utils.ExperimentalRpcApi -import kotlinx.rpc.internal.utils.InternalRpcApi -import kotlinx.rpc.internal.utils.map.RpcInternalConcurrentHashMap -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract -import kotlin.coroutines.CoroutineContext -import kotlin.coroutines.coroutineContext -import kotlin.js.JsName - -/** - * Stream scope handles all RPC streams that are launched inside it. - * Streams are alive until stream scope is. Streams can outlive their initial request scope. - * - * Streams are grouped by the request that initiated them. - * Each group can have a completion callback associated with it. - * - * Stream scope is a child of the [CoroutineContext] it was created in. - * Failure of one request will not cancel all streams in the others. - */ -@OptIn(InternalCoroutinesApi::class) -@Deprecated( - "StreamScope is deprecated, see https://kotlin.github.io/kotlinx-rpc/0-6-0.html", - level = DeprecationLevel.WARNING -) -public class StreamScope internal constructor( - parentContext: CoroutineContext, - internal val role: Role, -) : AutoCloseable { - internal class Element(internal val scope: StreamScope) : CoroutineContext.Element { - override val key: CoroutineContext.Key = Key - - internal companion object Key : CoroutineContext.Key - } - - internal val contextElement = Element(this) - - private val scopeJob = SupervisorJob(parentContext.job) - - private val requests = RpcInternalConcurrentHashMap() - - init { - scopeJob.invokeOnCompletion { - close() - } - } - - @InternalRpcApi - public fun onScopeCompletion(handler: (Throwable?) -> Unit) { - scopeJob.invokeOnCompletion(handler) - } - - @InternalRpcApi - public fun onScopeCompletion(callId: String, handler: (Throwable?) -> Unit) { - getRequestScope(callId).coroutineContext.job.invokeOnCompletion(onCancelling = true, handler = handler) - } - - @InternalRpcApi - public fun cancelRequestScopeById(callId: String, message: String, cause: Throwable?): Job? { - return requests.remove(callId)?.apply { cancel(message, cause) }?.coroutineContext?.job - } - - // Group stream launches by callId. In case one fails, so do others - @InternalRpcApi - public fun launch(callId: String, block: suspend CoroutineScope.() -> Unit): Job { - return getRequestScope(callId).launch(block = block) - } - - override fun close() { - scopeJob.cancel("Stream scope closed") - requests.clear() - } - - private fun getRequestScope(callId: String): CoroutineScope { - return requests.computeIfAbsent(callId) { CoroutineScope(Job(scopeJob.job)) } - } - - internal class CallScope(val callId: String) : CoroutineContext.Element { - object Key : CoroutineContext.Key - - override val key: CoroutineContext.Key<*> = Key - } - - @InternalRpcApi - public enum class Role { - Client, Server; - } -} - -@InternalRpcApi -public fun CoroutineContext.withClientStreamScope(): CoroutineContext = withStreamScope(StreamScope.Role.Client) - -@InternalRpcApi -public fun CoroutineContext.withServerStreamScope(): CoroutineContext = withStreamScope(StreamScope.Role.Server) - -@OptIn(InternalCoroutinesApi::class) -internal fun CoroutineContext.withStreamScope(role: StreamScope.Role): CoroutineContext { - return this + StreamScope(this, role).contextElement.apply { - this@withStreamScope.job.invokeOnCompletion(onCancelling = true) { scope.close() } - } -} - -@InternalRpcApi -public suspend fun streamScopeOrNull(): StreamScope? { - return currentCoroutineContext()[StreamScope.Element.Key]?.scope -} - -@InternalRpcApi -public fun streamScopeOrNull(scope: CoroutineScope): StreamScope? { - return scope.coroutineContext[StreamScope.Element.Key]?.scope -} - -internal fun noStreamScopeError(): Nothing { - error( - "Stream scopes can only be used inside the 'streamScoped' block. \n" + - "To use stream scope API on a client - wrap your call with 'streamScoped' block.\n" + - "To use stream scope API on a server - use must use 'streamScoped' block for this call on a client." - ) -} - -@InternalRpcApi -public suspend fun callScoped(callId: String, block: suspend CoroutineScope.() -> T): T { - val context = currentCoroutineContext() - - if (context[StreamScope.CallScope.Key] != null) { - error("Nested callScoped calls are not allowed") - } - - val callScope = StreamScope.CallScope(callId) - - return withContext(callScope, block) -} - -/** - * Defines lifetime for all RPC streams that are used inside it. - * When the [block] ends - all streams that were created inside it are canceled. - * The same happens when an exception is thrown. - * - * All RPC calls that use streams, either sending or receiving them, - * MUST use this scope to define their lifetime. - * - * Lifetimes inside [streamScoped] are hierarchical, - * meaning that there is parent lifetime for all calls inside this block, - * and each call has its own lifetime independent of others. - * This also means that all streams from one call share the same lifetime. - * - * Examples: - * ```kotlin - * streamScoped { - * val flow = flow { /* ... */ } - * service.sendStream(flow) // will stop sending updates when 'streamScoped' block is finished - * } - * ``` - * - * ```kotlin - * streamScoped { - * launch { - * val flow1 = flow { /* ... */ } - * service.sendStream(flow) - * } - * - * // if call with 'flow1' is canceled or failed - this flow will continue working - * launch { - * val flow2 = flow { /* ... */ } - * service.sendStream(flow) - * } - * } - * ``` - */ -@Deprecated( - "streamScoped is deprecated, see https://kotlin.github.io/kotlinx-rpc/0-6-0.html", - level = DeprecationLevel.WARNING -) -@OptIn(ExperimentalContracts::class) -public suspend fun streamScoped(block: suspend CoroutineScope.() -> T): T { - contract { - callsInPlace(block, InvocationKind.EXACTLY_ONCE) - } - - val context = currentCoroutineContext() - .apply { - checkContextForStreamScope() - } - - val streamScope = StreamScope(context, StreamScope.Role.Client) - - return withContext(streamScope.contextElement) { - streamScope.use { - block() - } - } -} - -private fun CoroutineContext.checkContextForStreamScope() { - if (this[StreamScope.Element] != null) { - error( - "One of the following caused a failure: \n" + - "- nested 'streamScoped' or `withStreamScope` calls are not allowed.\n" + - "- 'streamScoped' or `withStreamScope` calls are not allowed in server RPC services." - ) - } -} - -/** - * Creates a [StreamScope] entity for manual stream management. - */ -@JsName("StreamScope_fun") -@ExperimentalRpcApi -@Deprecated( - "StreamScoped is deprecated, see https://kotlin.github.io/kotlinx-rpc/0-6-0.html", - level = DeprecationLevel.WARNING -) -public fun StreamScope(parent: CoroutineContext): StreamScope { - parent.checkContextForStreamScope() - - return StreamScope(parent, StreamScope.Role.Client) -} - -/** - * Adds manually managed [StreamScope] to the current context. - */ -@OptIn(ExperimentalContracts::class) -@ExperimentalRpcApi -@Deprecated( - "withStreamScope is deprecated, see https://kotlin.github.io/kotlinx-rpc/0-6-0.html", - level = DeprecationLevel.WARNING -) -public suspend fun withStreamScope(scope: StreamScope, block: suspend CoroutineScope.() -> T): T { - contract { - callsInPlace(block, InvocationKind.EXACTLY_ONCE) - } - - currentCoroutineContext().checkContextForStreamScope() - - return withContext(scope.contextElement, block) -} - -/** - * This is a callback that will run when stream scope (created by [streamScoped] function) ends. - * Typically, this is used to release stream resources that may be occupied by a call: - * ```kotlin - * // service on server - * override suspend fun returnStateFlow(): StateFlow { - * val state = MutableStateFlow(-1) - * - * incomingHotFlowJob = launch { - * repeat(Int.MAX_VALUE) { value -> - * state.value = value - * - * delay(1000) // intense work - * } - * } - * - * // release resources allocated for state flow, when it is closed on the client - * invokeOnStreamScopeCompletion { - * incomingHotFlowJob.cancel() - * } - * - * return state - * } - * ``` - */ -@ExperimentalRpcApi -@Deprecated( - "invokeOnStreamScopeCompletion is deprecated, see https://kotlin.github.io/kotlinx-rpc/0-6-0.html", - level = DeprecationLevel.WARNING -) -public suspend fun invokeOnStreamScopeCompletion(throwIfNoScope: Boolean = true, block: (Throwable?) -> Unit) { - val streamScope = streamScopeOrNull() ?: noStreamScopeError() - - if (streamScope.role == StreamScope.Role.Client) { - streamScope.onScopeCompletion(block) - return - } - - val callScope = coroutineContext[StreamScope.CallScope.Key] - - when { - callScope != null -> streamScope.onScopeCompletion(callScope.callId, block) - - throwIfNoScope -> error( - "'invokeOnStreamScopeCompletion' can only be called with corresponding 'streamScoped' block on a client" - ) - } -} diff --git a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/KrpcServiceHandler.kt b/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/KrpcServiceHandler.kt deleted file mode 100644 index 487327fd6..000000000 --- a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/KrpcServiceHandler.kt +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. - */ - -package kotlinx.rpc.krpc.internal - -import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.SharedFlow -import kotlinx.coroutines.flow.StateFlow -import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock -import kotlinx.rpc.descriptor.RpcCallable -import kotlinx.rpc.descriptor.RpcInvokator -import kotlinx.rpc.internal.utils.InternalRpcApi -import kotlinx.rpc.krpc.KrpcConfig -import kotlinx.rpc.krpc.internal.logging.RpcInternalCommonLogger -import kotlinx.serialization.BinaryFormat -import kotlinx.serialization.KSerializer -import kotlinx.serialization.SerialFormat -import kotlinx.serialization.StringFormat -import kotlinx.serialization.modules.SerializersModule - -@InternalRpcApi -public abstract class KrpcServiceHandler { - protected abstract val sender: KrpcMessageSender - protected abstract val config: KrpcConfig - protected abstract val logger: RpcInternalCommonLogger - - protected suspend fun handleIncomingHotFlows(streamContext: KrpcStreamContext) { - for (hotFlow in streamContext.incomingHotFlows) { - streamContext.launch { - /** Start consuming incoming requests, see [KrpcIncomingHotFlow.emit] */ - hotFlow.emit(null) - } - } - } - - protected suspend fun handleOutgoingStreams( - streamContext: KrpcStreamContext, - serialFormat: SerialFormat, - serviceTypeString: String, - ) { - val mutex = Mutex() - for (outgoingStream in streamContext.outgoingStreams) { - streamContext.launch { - try { - when (outgoingStream.kind) { - StreamKind.Flow, StreamKind.SharedFlow, StreamKind.StateFlow -> { - val stream = outgoingStream.stream as Flow<*> - - collectAndSendOutgoingStream( - mutex = mutex, - serialFormat = serialFormat, - flow = stream, - outgoingStream = outgoingStream, - serviceTypeString = serviceTypeString, - ) - } - } - } catch (e : CancellationException) { - // canceled by a streamScope - throw e - } catch (@Suppress("detekt.TooGenericExceptionCaught") cause: Throwable) { - mutex.withLock { - val serializedReason = serializeException(cause) - val message = KrpcCallMessage.StreamCancel( - callId = outgoingStream.callId, - serviceType = serviceTypeString, - streamId = outgoingStream.streamId, - cause = serializedReason, - connectionId = outgoingStream.connectionId, - serviceId = outgoingStream.serviceId, - ) - sender.sendMessage(message) - } - throw cause - } - - mutex.withLock { - val message = KrpcCallMessage.StreamFinished( - callId = outgoingStream.callId, - serviceType = serviceTypeString, - streamId = outgoingStream.streamId, - connectionId = outgoingStream.connectionId, - serviceId = outgoingStream.serviceId, - ) - - sender.sendMessage(message) - } - } - } - } - - @Suppress("detekt.LongParameterList") - private suspend fun collectAndSendOutgoingStream( - mutex: Mutex, - serialFormat: SerialFormat, - flow: Flow<*>, - serviceTypeString: String, - outgoingStream: KrpcStreamCall, - ) { - flow.collect { - // because we can send new message for the new flow, - // which is not published with `transport.send(message)` - mutex.withLock { - val message = when (serialFormat) { - is StringFormat -> { - val stringData = serialFormat.encodeToString(outgoingStream.elementSerializer, it) - KrpcCallMessage.StreamMessageString( - callId = outgoingStream.callId, - serviceType = serviceTypeString, - streamId = outgoingStream.streamId, - data = stringData, - connectionId = outgoingStream.connectionId, - serviceId = outgoingStream.serviceId, - ) - } - - is BinaryFormat -> { - val binaryData = serialFormat.encodeToByteArray(outgoingStream.elementSerializer, it) - KrpcCallMessage.StreamMessageBinary( - callId = outgoingStream.callId, - serviceType = serviceTypeString, - streamId = outgoingStream.streamId, - data = binaryData, - connectionId = outgoingStream.connectionId, - serviceId = outgoingStream.serviceId, - ) - } - - else -> { - unsupportedSerialFormatError(serialFormat) - } - } - - sender.sendMessage(message) - } - } - } - - protected fun prepareSerialFormat(rpcFlowContext: LazyKrpcStreamContext): SerialFormat { - val module = SerializersModule { - contextual(Flow::class) { - @Suppress("UNCHECKED_CAST") - StreamSerializer.Flow(rpcFlowContext.initialize(), it.first() as KSerializer) - } - - contextual(SharedFlow::class) { - @Suppress("UNCHECKED_CAST") - StreamSerializer.SharedFlow(rpcFlowContext.initialize(), it.first() as KSerializer) - } - - contextual(StateFlow::class) { - @Suppress("UNCHECKED_CAST") - StreamSerializer.StateFlow(rpcFlowContext.initialize(), it.first() as KSerializer) - } - } - - return config.serialFormatInitializer.applySerializersModuleAndBuild(module) - } - - protected fun RpcCallable<*>.toMessageCallType(): KrpcCallMessage.CallType { - return when (invokator) { - is RpcInvokator.Method -> KrpcCallMessage.CallType.Method - is RpcInvokator.Field -> KrpcCallMessage.CallType.Field - } - } -} diff --git a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/KrpcStreamContext.kt b/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/KrpcStreamContext.kt deleted file mode 100644 index 73fbd3a24..000000000 --- a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/KrpcStreamContext.kt +++ /dev/null @@ -1,355 +0,0 @@ -/* - * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. - */ - -package kotlinx.rpc.krpc.internal - -import kotlinx.atomicfu.atomic -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.flow.FlowCollector -import kotlinx.coroutines.flow.MutableSharedFlow -import kotlinx.coroutines.flow.MutableStateFlow -import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.selects.select -import kotlinx.rpc.internal.utils.InternalRpcApi -import kotlinx.rpc.internal.utils.getDeferred -import kotlinx.rpc.internal.utils.getOrNull -import kotlinx.rpc.internal.utils.map.RpcInternalConcurrentHashMap -import kotlinx.rpc.internal.utils.set -import kotlinx.rpc.krpc.KrpcConfig -import kotlinx.rpc.krpc.StreamScope -import kotlinx.rpc.krpc.noStreamScopeError -import kotlinx.serialization.KSerializer -import kotlinx.serialization.SerialFormat -import kotlin.coroutines.CoroutineContext - -@InternalRpcApi -public class LazyKrpcStreamContext( - public val streamScopeOrNull: StreamScope?, - private val fallbackScope: StreamScope? = null, - private val initializer: (StreamScope) -> KrpcStreamContext, -) { - private val deferred = CompletableDeferred() - private val lazyValue by lazy(LazyThreadSafetyMode.SYNCHRONIZED) { - if (streamScopeOrNull == null && (STREAM_SCOPES_ENABLED || fallbackScope == null)) { - noStreamScopeError() - } - - // null pointer is impossible - val streamScope = streamScopeOrNull ?: fallbackScope!! - initializer(streamScope).also { deferred.complete(it) } - } - - public suspend fun awaitInitialized(): KrpcStreamContext = deferred.await() - - public val valueOrNull: KrpcStreamContext? get() = if (deferred.isCompleted) lazyValue else null - - public fun initialize(): KrpcStreamContext = lazyValue -} - -@InternalRpcApi -public class KrpcStreamContext( - private val callId: String, - private val config: KrpcConfig, - private val connectionId: Long?, - private val serviceId: Long?, - public val streamScope: StreamScope, -) { - private companion object { - private const val STREAM_ID_PREFIX = "stream:" - } - private val closed = CompletableDeferred() - - // thread-safe set - private val closedStreams = RpcInternalConcurrentHashMap>() - - @InternalRpcApi - public inline fun launchIf( - condition: KrpcStreamContext.() -> Boolean, - noinline block: suspend CoroutineScope.(KrpcStreamContext) -> Unit, - ) { - if (condition(this)) { - launch(block) - } - } - - public fun launch(block: suspend CoroutineScope.(KrpcStreamContext) -> Unit) { - streamScope.launch(callId) { - block(this@KrpcStreamContext) - } - } - - public fun cancel(message: String, cause: Throwable?): Job? { - return streamScope.cancelRequestScopeById(callId, message, cause) - } - - init { - streamScope.onScopeCompletion(callId) { cause -> - close(cause) - } - } - - private val streamIdCounter = atomic(0L) - - public val incomingHotFlowsAvailable: Boolean get() = incomingHotFlowsInitialized - - public val outgoingStreamsAvailable: Boolean get() = outgoingStreamsInitialized - - private var incomingStreamsInitialized: Boolean = false - private val incomingStreams by lazy { - incomingStreamsInitialized = true - RpcInternalConcurrentHashMap>() - } - - private var incomingChannelsInitialized: Boolean = false - private val incomingChannels by lazy { - incomingChannelsInitialized = true - RpcInternalConcurrentHashMap?>>() - } - - private var outgoingStreamsInitialized: Boolean = false - internal val outgoingStreams: Channel by lazy { - outgoingStreamsInitialized = true - Channel(capacity = Channel.UNLIMITED) - } - - private var incomingHotFlowsInitialized: Boolean = false - internal val incomingHotFlows: Channel> by lazy { - incomingHotFlowsInitialized = true - Channel(Channel.UNLIMITED) - } - - internal fun registerOutgoingStream( - stream: StreamT, - streamKind: StreamKind, - elementSerializer: KSerializer, - ): String { - val id = "$STREAM_ID_PREFIX${streamIdCounter.getAndIncrement()}" - outgoingStreams.trySend( - KrpcStreamCall( - callId = callId, - streamId = id, - stream = stream, - kind = streamKind, - elementSerializer = elementSerializer, - connectionId = connectionId, - serviceId = serviceId, - ) - ) - return id - } - - internal fun prepareIncomingStream( - streamId: String, - streamKind: StreamKind, - stateFlowInitialValue: Any?, - elementSerializer: KSerializer, - ): StreamT { - val incoming: Channel = Channel(Channel.UNLIMITED) - incomingChannels[streamId] = incoming - - val stream = streamOf(streamId, streamKind, stateFlowInitialValue, incoming) - incomingStreams[streamId] = KrpcStreamCall( - callId = callId, - streamId = streamId, - stream = stream, - kind = streamKind, - elementSerializer = elementSerializer, - connectionId = connectionId, - serviceId = serviceId, - ) - return stream - } - - @Suppress("UNCHECKED_CAST") - private fun streamOf( - streamId: String, - streamKind: StreamKind, - stateFlowInitialValue: Any?, - incoming: Channel, - ): StreamT { - suspend fun consumeFlow(collector: FlowCollector, onError: (Throwable) -> Unit) { - fun onClose() { - incoming.cancel() - - closedStreams[streamId] = Unit - incomingChannels.remove(streamId)?.complete(null) - incomingStreams.remove(streamId) - } - - for (message in incoming) { - when (message) { - is StreamCancel -> { - onClose() - onError(message.cause ?: streamCanceled()) - } - - is StreamEnd -> { - onClose() - if (streamKind != StreamKind.Flow) { - onError(streamCanceled()) - } - - return - } - - else -> { - collector.emit(message) - } - } - } - } - - return when (streamKind) { - StreamKind.Flow -> { - flow { - consumeFlow(this) { e -> throw e } - } - } - - StreamKind.SharedFlow -> { - val sharedFlow: MutableSharedFlow = config.sharedFlowBuilder() - - object : RpcIncomingHotFlow(sharedFlow, ::consumeFlow), MutableSharedFlow by sharedFlow { - override suspend fun collect(collector: FlowCollector): Nothing { - super.collect(collector) - } - - override suspend fun emit(value: Any?) { - super.emit(value) - } - }.also { incomingHotFlows.trySend(it) } - } - - StreamKind.StateFlow -> { - val stateFlow = MutableStateFlow(stateFlowInitialValue) - - object : RpcIncomingHotFlow(stateFlow, ::consumeFlow), MutableStateFlow by stateFlow { - override suspend fun collect(collector: FlowCollector): Nothing { - super.collect(collector) - } - - override suspend fun emit(value: Any?) { - super.emit(value) - } - }.also { incomingHotFlows.trySend(it) } - } - } as StreamT - } - - public suspend fun closeStream(message: KrpcCallMessage.StreamFinished) { - incomingChannelOf(message.streamId)?.send(StreamEnd) - } - - public suspend fun cancelStream(message: KrpcCallMessage.StreamCancel) { - incomingChannelOf(message.streamId)?.send(StreamCancel(message.cause.deserialize())) - } - - public suspend fun send(message: KrpcCallMessage.StreamMessage, serialFormat: SerialFormat) { - val info: KrpcStreamCall? = select { - incomingStreams.getDeferred(message.streamId).onAwait { it } - closedStreams.getDeferred(message.streamId).onAwait { null } - closed.onAwait { null } - } - if (info == null) return - val result = decodeMessageData(serialFormat, info.elementSerializer, message) - val channel = incomingChannelOf(message.streamId) - channel?.send(result) - } - - private suspend fun incomingChannelOf(streamId: String): Channel? { - return select { - incomingChannels.getDeferred(streamId).onAwait { it } - closedStreams.getDeferred(streamId).onAwait { null } - closed.onAwait { null } - } - } - - private fun close(cause: Throwable?) { - if (closed.isCompleted) { - return - } - - closed.complete(Unit) - - if (incomingChannelsInitialized) { - for (channel in incomingChannels.values) { - if (!channel.isCompleted) { - continue - } - - @OptIn(ExperimentalCoroutinesApi::class) - channel.getCompleted()?.apply { - trySend(StreamEnd) - - // close for sending, but not for receiving our cancel message, if possible. - close(cause) - } - } - - incomingChannels.clear() - } - - if (incomingStreamsInitialized) { - incomingStreams.values - .mapNotNull { it.getOrNull()?.stream } - .filterIsInstance() - .forEach { stream -> - stream.subscriptionContexts.forEach { - it.cancel(CancellationException("Stream closed", cause)) - } - } - - incomingStreams.clear() - } - - if (outgoingStreamsInitialized) { - outgoingStreams.close() - } - - if (incomingHotFlowsInitialized) { - incomingHotFlows.close() - } - } -} - -private fun streamCanceled() = NoSuchElementException("Stream canceled") - -private object StreamEnd - -private class StreamCancel(val cause: Throwable? = null) - -private abstract class RpcIncomingHotFlow( - private val rawFlow: MutableSharedFlow, - private val consume: suspend (FlowCollector, onError: (Throwable) -> Unit) -> Unit, -) : MutableSharedFlow { - val subscriptionContexts by lazy { mutableSetOf() } - - override suspend fun collect(collector: FlowCollector): Nothing { - val context = currentCoroutineContext() - - if (context.isActive) { - subscriptionContexts.add(context) - - context.job.invokeOnCompletion { - subscriptionContexts.remove(context) - } - } - - try { - rawFlow.collect(collector) - } finally { - subscriptionContexts.remove(context) - } - } - - // value can be ignored, as actual values are coming from the rawFlow - override suspend fun emit(value: Any?) { - consume(rawFlow) { e -> - subscriptionContexts.forEach { it.cancel(CancellationException(e.message, e)) } - - subscriptionContexts.clear() - } - } -} diff --git a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/StreamSerializer.kt b/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/StreamSerializer.kt index f34e22c77..d15fa14ff 100644 --- a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/StreamSerializer.kt +++ b/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/StreamSerializer.kt @@ -1,87 +1,20 @@ /* - * Copyright 2023-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. */ package kotlinx.rpc.krpc.internal -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.KSerializer -import kotlinx.serialization.descriptors.* -import kotlinx.serialization.encoding.Decoder -import kotlinx.serialization.encoding.Encoder +import kotlinx.rpc.internal.utils.InternalRpcApi +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor +import kotlinx.serialization.descriptors.SerialDescriptor -internal sealed class StreamSerializer(private val streamKind: StreamKind) : KSerializer { - companion object { - private const val STREAM_SERIALIZER_NAME_PREFIX = "StreamSerializer" +@Suppress("PropertyName") +@InternalRpcApi +public abstract class StreamSerializer { + protected val STREAM_ID_SERIAL_NAME: String = "streamId" + protected val STREAM_ID_SERIALIZER_NAME: String = "StreamIdSerializer" - private const val STREAM_ID_SERIAL_NAME = "streamId" - private const val STREAM_ID_SERIALIZER_NAME = "StreamIdSerializer" - - private const val STATE_FLOW_INITIAL_VALUE_SERIAL_NAME = "stateFlowInitialValue" - } - - protected abstract val context: KrpcStreamContext - protected abstract val elementType: KSerializer - - protected open fun ClassSerialDescriptorBuilder.descriptorExtension() { } - - protected open fun decodeStateFlowInitialValue(decoder: Decoder): Any? { return null } - - protected open fun encodeStateFlowInitialValue(encoder: Encoder, flow: StreamT) { } - - override val descriptor: SerialDescriptor by lazy { - buildClassSerialDescriptor("$STREAM_SERIALIZER_NAME_PREFIX.${streamKind.name}") { - element(STREAM_ID_SERIAL_NAME, PrimitiveSerialDescriptor(STREAM_ID_SERIALIZER_NAME, PrimitiveKind.STRING)) - descriptorExtension() - } - } - - override fun deserialize(decoder: Decoder): StreamT { - val streamId = decoder.decodeString() - - val stateFlowValue = decodeStateFlowInitialValue(decoder) - - return context.prepareIncomingStream(streamId, streamKind, stateFlowValue, elementType) as StreamT - } - - override fun serialize(encoder: Encoder, value: StreamT) { - val id = context.registerOutgoingStream(value, streamKind, elementType) - - encoder.encodeString(id) - - encodeStateFlowInitialValue(encoder, value) - } - - class Flow( - override val context: KrpcStreamContext, - override val elementType: KSerializer, - ) : StreamSerializer>(StreamKind.Flow) - - class SharedFlow( - override val context: KrpcStreamContext, - override val elementType: KSerializer, - ) : StreamSerializer>(StreamKind.SharedFlow) - - class StateFlow( - override val context: KrpcStreamContext, - override val elementType: KSerializer, - ) : StreamSerializer>(StreamKind.StateFlow) { - override fun ClassSerialDescriptorBuilder.descriptorExtension() { - element(STATE_FLOW_INITIAL_VALUE_SERIAL_NAME, elementType.descriptor, isOptional = true) - } - - @OptIn(ExperimentalSerializationApi::class) - override fun decodeStateFlowInitialValue(decoder: Decoder): Any? { - return decoder.decodeNullableSerializableValue(elementType) - } - - @OptIn(ExperimentalSerializationApi::class) - override fun encodeStateFlowInitialValue(encoder: Encoder, flow: kotlinx.coroutines.flow.StateFlow) { - encoder.encodeNullableSerializableValue(elementType, flow.value) - } - } -} - -internal enum class StreamKind { - Flow, SharedFlow, StateFlow; + protected val streamIdDescriptor: SerialDescriptor = + PrimitiveSerialDescriptor(STREAM_ID_SERIALIZER_NAME, PrimitiveKind.STRING) } diff --git a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/devStreamScope.kt b/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/devStreamScope.kt deleted file mode 100644 index ca89f135a..000000000 --- a/krpc/krpc-core/src/commonMain/kotlin/kotlinx/rpc/krpc/internal/devStreamScope.kt +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright 2023-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. - */ - -package kotlinx.rpc.krpc.internal - -import kotlinx.rpc.internal.utils.InternalRpcApi - -/** - * For legacy internal users ONLY. - * Special dev builds may set this value to `false`. - * - * If the value is `false`, absence of [kotlinx.rpc.krpc.streamScoped] for a call - * is replaced with service's [kotlinx.rpc.krpc.StreamScope] - * obtained via [kotlinx.rpc.krpc.withClientStreamScope]. - */ -@InternalRpcApi -public const val STREAM_SCOPES_ENABLED: Boolean = true diff --git a/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/KrpcServerService.kt b/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/KrpcServerService.kt index c7d5532ed..496938266 100644 --- a/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/KrpcServerService.kt +++ b/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/KrpcServerService.kt @@ -11,30 +11,24 @@ import kotlinx.rpc.descriptor.RpcInvokator import kotlinx.rpc.descriptor.RpcServiceDescriptor import kotlinx.rpc.internal.utils.map.RpcInternalConcurrentHashMap import kotlinx.rpc.krpc.KrpcConfig -import kotlinx.rpc.krpc.callScoped import kotlinx.rpc.krpc.internal.* import kotlinx.rpc.krpc.internal.logging.RpcInternalCommonLogger -import kotlinx.rpc.krpc.streamScopeOrNull -import kotlinx.rpc.krpc.withServerStreamScope import kotlinx.serialization.BinaryFormat import kotlinx.serialization.KSerializer import kotlinx.serialization.SerialFormat import kotlinx.serialization.StringFormat +import kotlinx.serialization.modules.SerializersModule import kotlin.coroutines.CoroutineContext import kotlin.reflect.typeOf internal class KrpcServerService<@Rpc T : Any>( private val service: T, private val descriptor: RpcServiceDescriptor, - override val config: KrpcConfig.Server, + private val config: KrpcConfig.Server, private val connector: KrpcServerConnector, - coroutineContext: CoroutineContext, -) : KrpcServiceHandler(), CoroutineScope { - override val logger = RpcInternalCommonLogger.logger(rpcInternalObjectId(descriptor.fqName)) - override val sender: KrpcMessageSender get() = connector - private val scope: CoroutineScope = this - - override val coroutineContext: CoroutineContext = coroutineContext.withServerStreamScope() + override val coroutineContext: CoroutineContext, +) : CoroutineScope { + private val logger = RpcInternalCommonLogger.logger(rpcInternalObjectId(descriptor.fqName)) private val requestMap = RpcInternalConcurrentHashMap() @@ -71,7 +65,7 @@ internal class KrpcServerService<@Rpc T : Any>( connectionId = message.connectionId, ) - sender.sendMessage(errorMessage) + connector.sendMessage(errorMessage) } } @@ -97,37 +91,24 @@ internal class KrpcServerService<@Rpc T : Any>( is KrpcCallMessage.StreamCancel -> { // if no stream is present, it probably was already canceled - getAndAwaitStreamContext(message) - ?.cancelStream(message) + serverStreamContext.cancelStream(message) } is KrpcCallMessage.StreamFinished -> { // if no stream is present, it probably was already finished - getAndAwaitStreamContext(message) - ?.closeStream(message) + serverStreamContext.closeStream(message) } is KrpcCallMessage.StreamMessage -> { - requestMap[message.callId]?.streamContext?.apply { - awaitInitialized().send(message, prepareSerialFormat(this)) - } ?: error("Invalid request call id: ${message.callId}") + serverStreamContext.send(message, serialFormat) } } } - private suspend fun getAndAwaitStreamContext(message: KrpcCallMessage): KrpcStreamContext? { - return requestMap[message.callId]?.streamContext?.awaitInitialized() - } - @Suppress("detekt.ThrowsCount", "detekt.LongMethod") private fun handleCall(callData: KrpcCallMessage.CallData) { val callId = callData.callId - val streamContext = LazyKrpcStreamContext(streamScopeOrNull(scope)) { - KrpcStreamContext(callId, config, callData.connectionId, callData.serviceId, it) - } - val serialFormat = prepareSerialFormat(streamContext) - val isMethod = when (callData.callType) { KrpcCallMessage.CallType.Method -> true KrpcCallMessage.CallType.Field -> false @@ -148,7 +129,9 @@ internal class KrpcServerService<@Rpc T : Any>( val data = if (isMethod) { val serializerModule = serialFormat.serializersModule val paramsSerializer = serializerModule.rpcSerializerForType(callable.dataType) - decodeMessageData(serialFormat, paramsSerializer, callData) + serverStreamContext.scoped(callId) { + decodeMessageData(serialFormat, paramsSerializer, callData) + } } else { null } @@ -171,9 +154,7 @@ internal class KrpcServerService<@Rpc T : Any>( val value = when (val invokator = callable.invokator) { is RpcInvokator.Method -> { - callScoped(callId) { - invokator.call(service, data) - } + invokator.call(service, data) } is RpcInvokator.Field -> { @@ -216,27 +197,17 @@ internal class KrpcServerService<@Rpc T : Any>( cause = serializedCause, connectionId = callData.connectionId, serviceId = callData.serviceId, - ).also { sender.sendMessage(it) } + ).also { connector.sendMessage(it) } } - if (failure == null) { - streamContext.valueOrNull?.apply { - launchIf({ incomingHotFlowsAvailable }) { - handleIncomingHotFlows(it) - } - - launchIf({ outgoingStreamsAvailable }) { - handleOutgoingStreams(it, serialFormat, descriptor.fqName) - } - } ?: run { - cancelRequest(callId, fromJob = true) - } - } else { + if (failure != null) { cancelRequest(callId, "Server request failed", failure, fromJob = true) + } else { + cancelRequest(callId, fromJob = true) } } - requestMap[callId] = RpcRequest(requestJob, streamContext) + requestMap[callId] = RpcRequest(requestJob, serverStreamContext) requestJob.start() } @@ -275,7 +246,7 @@ internal class KrpcServerService<@Rpc T : Any>( } } - sender.sendMessage(result) + connector.sendMessage(result) } private suspend fun sendFlowMessages( @@ -351,10 +322,11 @@ internal class KrpcServerService<@Rpc T : Any>( cause: Throwable? = null, fromJob: Boolean = false, ) { + serverStreamContext.removeCall(callId, cause) requestMap.remove(callId)?.cancelAndClose(callId, message, cause, fromJob) // acknowledge the cancellation - sender.sendMessage( + connector.sendMessage( KrpcGenericMessage( connectionId = null, pluginParams = mapOf( @@ -366,6 +338,19 @@ internal class KrpcServerService<@Rpc T : Any>( ) } + private val serverStreamContext: ServerStreamContext = ServerStreamContext() + + private val serialFormat: SerialFormat by lazy { + val module = SerializersModule { + contextual(Flow::class) { + @Suppress("UNCHECKED_CAST") + ServerStreamSerializer(serverStreamContext, it.first() as KSerializer) + } + } + + config.serialFormatInitializer.applySerializersModuleAndBuild(module) + } + companion object { // streams in non-suspend server functions are unique in each call, so no separate is in needed // this one is provided as a way to interact with the old code around streams @@ -373,7 +358,7 @@ internal class KrpcServerService<@Rpc T : Any>( } } -internal class RpcRequest(val handlerJob: Job, val streamContext: LazyKrpcStreamContext) { +internal class RpcRequest(val handlerJob: Job, val streamContext: ServerStreamContext) { suspend fun cancelAndClose( callId: String, message: String? = null, @@ -390,13 +375,6 @@ internal class RpcRequest(val handlerJob: Job, val streamContext: LazyKrpcStream handlerJob.join() } - val ctx = streamContext.valueOrNull - if (ctx == null) { - streamContext.streamScopeOrNull - ?.cancelRequestScopeById(callId, message ?: "Scope cancelled", cause) - ?.join() - } else { - ctx.cancel(message ?: "Request cancelled", cause)?.join() - } + streamContext.removeCall(callId, cause) } } diff --git a/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/ServerStreamContext.kt b/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/ServerStreamContext.kt new file mode 100644 index 000000000..fd9066900 --- /dev/null +++ b/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/ServerStreamContext.kt @@ -0,0 +1,98 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.krpc.server.internal + +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import kotlinx.rpc.internal.utils.map.RpcInternalConcurrentHashMap +import kotlinx.rpc.krpc.internal.KrpcCallMessage +import kotlinx.rpc.krpc.internal.decodeMessageData +import kotlinx.rpc.krpc.internal.deserialize +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialFormat +import kotlin.native.concurrent.ThreadLocal + +internal class ServerStreamContext { + @ThreadLocal + private var currentCallId: String? = null + + fun scoped(callId: String, body: () -> T): T { + try { + currentCallId = callId + return body() + } finally { + currentCallId = null + } + } + + private val streams = RpcInternalConcurrentHashMap>() + + suspend fun send(message: KrpcCallMessage.StreamMessage, serialFormat: SerialFormat) { + val call = streams[message.callId]?.get(message.streamId) ?: return + val data = scoped(message.streamId) { + decodeMessageData(serialFormat, call.elementSerializer, message) + } + call.channel.send(data) + } + + suspend fun cancelStream(message: KrpcCallMessage.StreamCancel) { + streams[message.callId]?.get(message.streamId)?.channel?.send(StreamCancel(message.cause.deserialize())) + } + + suspend fun closeStream(message: KrpcCallMessage.StreamFinished) { + streams[message.callId]?.get(message.streamId)?.channel?.send(StreamEnd) + } + + fun removeCall(callId: String, cause: Throwable?) { + streams.remove(callId)?.values?.forEach { + it.channel.close(cause) + } + } + + fun prepareClientStream(streamId: String, elementKind: KSerializer): Flow { + val callId = currentCallId ?: error("No call id") + + val channel = Channel(Channel.UNLIMITED) + + @Suppress("UNCHECKED_CAST") + val map = streams.computeIfAbsent(callId) { RpcInternalConcurrentHashMap() } + map[streamId] = StreamCall(callId, streamId, channel, elementKind) + + fun onClose() { + channel.cancel() + map.remove(streamId) + } + + val flow = flow { + for (message in channel) { + println("Consumed on server: $message") + when (message) { + is StreamCancel -> { + onClose() + throw message.cause ?: streamCanceled() + } + + is StreamEnd -> { + onClose() + + return@flow + } + + else -> { + emit(message) + } + } + } + } + + return flow + } + + private fun streamCanceled(): Throwable = NoSuchElementException("Stream canceled") +} + +private data class StreamCancel(val cause: Throwable? = null) +private data object StreamEnd diff --git a/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/ServerStreamSerializer.kt b/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/ServerStreamSerializer.kt new file mode 100644 index 000000000..be9b292bc --- /dev/null +++ b/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/ServerStreamSerializer.kt @@ -0,0 +1,33 @@ +/* + * Copyright 2023-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.krpc.server.internal + +import kotlinx.coroutines.flow.Flow +import kotlinx.rpc.krpc.internal.StreamSerializer +import kotlinx.serialization.KSerializer +import kotlinx.serialization.descriptors.* +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder + +internal class ServerStreamSerializer( + val context: ServerStreamContext, + val elementType: KSerializer, +) : KSerializer>, StreamSerializer() { + override val descriptor: SerialDescriptor by lazy { + buildClassSerialDescriptor("ServerStreamSerializer") { + element(STREAM_ID_SERIAL_NAME, streamIdDescriptor) + } + } + + override fun deserialize(decoder: Decoder): Flow<*> { + val streamId = decoder.decodeString() + + return context.prepareClientStream(streamId, elementType) + } + + override fun serialize(encoder: Encoder, value: Flow<*>) { + error("This method must not be called. Please report to the developer.") + } +} diff --git a/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/StreamCall.kt b/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/StreamCall.kt new file mode 100644 index 000000000..9ff3fb526 --- /dev/null +++ b/krpc/krpc-server/src/commonMain/kotlin/kotlinx/rpc/krpc/server/internal/StreamCall.kt @@ -0,0 +1,15 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.krpc.server.internal + +import kotlinx.coroutines.channels.Channel +import kotlinx.serialization.KSerializer + +internal data class StreamCall( + val callId: String, + val streamId: String, + val channel: Channel, + val elementSerializer: KSerializer, +) diff --git a/krpc/krpc-test/build.gradle.kts b/krpc/krpc-test/build.gradle.kts index 0ba1b3566..c5160f950 100644 --- a/krpc/krpc-test/build.gradle.kts +++ b/krpc/krpc-test/build.gradle.kts @@ -28,6 +28,9 @@ kotlin { api(projects.krpc.krpcCore) api(projects.krpc.krpcServer) api(projects.krpc.krpcClient) + api(projects.krpc.krpcLogging) + + implementation(libs.coroutines.debug) implementation(projects.krpc.krpcSerialization.krpcSerializationJson) @@ -101,15 +104,3 @@ tasks.register("moveToGold") { } } } - -rpc { - strict { - stateFlow = RpcStrictMode.NONE - sharedFlow = RpcStrictMode.NONE - nestedFlow = RpcStrictMode.NONE - streamScopedFunctions = RpcStrictMode.NONE - suspendingServerStreaming = RpcStrictMode.NONE - notTopLevelServerFlow = RpcStrictMode.NONE - fields = RpcStrictMode.NONE - } -} diff --git a/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestService.kt b/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestService.kt index d27ab2caa..3babc86ba 100644 --- a/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestService.kt +++ b/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestService.kt @@ -5,8 +5,6 @@ package kotlinx.rpc.krpc.test import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.SharedFlow -import kotlinx.coroutines.flow.StateFlow import kotlinx.rpc.RemoteService import kotlinx.rpc.annotations.Rpc import kotlinx.serialization.Contextual @@ -73,29 +71,15 @@ interface KrpcTestService : RemoteService { ): String suspend fun incomingStreamSyncCollect(arg1: Flow): Int - suspend fun incomingStreamAsyncCollect(arg1: Flow): Int - suspend fun outgoingStream(): Flow - suspend fun bidirectionalStream(arg1: Flow): Flow - suspend fun echoStream(arg1: Flow): Flow + suspend fun incomingStreamSyncCollectMultiple(arg1: Flow, arg2: Flow, arg3: Flow): Int + fun outgoingStream(): Flow + fun bidirectionalStream(arg1: Flow): Flow + fun echoStream(arg1: Flow): Flow suspend fun streamInDataClass(payloadWithStream: PayloadWithStream): Int - suspend fun streamInStream(payloadWithStream: Flow): Int - suspend fun streamOutDataClass(): PayloadWithStream - suspend fun streamOfStreamsInReturn(): Flow> - suspend fun streamOfPayloadsInReturn(): Flow - - suspend fun streamInDataClassWithStream(payloadWithPayload: PayloadWithPayload): Int - suspend fun streamInStreamWithStream(payloadWithPayload: Flow): Int - suspend fun returnPayloadWithPayload(): PayloadWithPayload - suspend fun returnFlowPayloadWithPayload(): Flow - - suspend fun bidirectionalFlowOfPayloadWithPayload( - payloadWithPayload: Flow - ): Flow - - suspend fun getNInts(n: Int): Flow - suspend fun getNIntsBatched(n: Int): Flow> + fun getNInts(n: Int): Flow + fun getNIntsBatched(n: Int): Flow> suspend fun bytes(byteArray: ByteArray) suspend fun nullableBytes(byteArray: ByteArray?) @@ -107,15 +91,11 @@ interface KrpcTestService : RemoteService { suspend fun nullableInt(v: Int?): Int? suspend fun nullableList(v: List?): List? - suspend fun delayForever(): Flow + fun delayForever(): Flow suspend fun answerToAnything(arg: String): Int suspend fun krpc173() fun unitFlow(): Flow - - suspend fun sharedFlowInFunction(sharedFlow: SharedFlow): StateFlow - - suspend fun stateFlowInFunction(stateFlow: StateFlow): StateFlow } diff --git a/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.kt b/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.kt index 45b107c46..195cac0e9 100644 --- a/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.kt +++ b/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.kt @@ -14,10 +14,6 @@ import kotlin.test.assertEquals @OptIn(ExperimentalCoroutinesApi::class) class KrpcTestServiceBackend(override val coroutineContext: CoroutineContext) : KrpcTestService { - companion object { - const val SHARED_FLOW_REPLAY = 5 - } - override fun nonSuspendFlow(): Flow { return flow { repeat(10) { @@ -133,27 +129,23 @@ class KrpcTestServiceBackend(override val coroutineContext: CoroutineContext) : return arg1.count() } - val incomingStreamAsyncCollectLatch = CompletableDeferred() - - @OptIn(DelicateCoroutinesApi::class) - override suspend fun incomingStreamAsyncCollect(arg1: Flow): Int { - @Suppress("detekt.GlobalCoroutineUsage") - GlobalScope.launch { - assertContentEquals(listOf("test1", "test2", "test3"), arg1.toList()) - incomingStreamAsyncCollectLatch.complete(Unit) - } - return 5 + override suspend fun incomingStreamSyncCollectMultiple( + arg1: Flow, + arg2: Flow, + arg3: Flow, + ): Int { + return arg1.count() + arg2.count() + arg3.count() } - override suspend fun outgoingStream(): Flow { + override fun outgoingStream(): Flow { return flow { emit("a"); emit("b"); emit("c") } } - override suspend fun bidirectionalStream(arg1: Flow): Flow { + override fun bidirectionalStream(arg1: Flow): Flow { return arg1.map { it.reversed() } } - override suspend fun echoStream(arg1: Flow): Flow = flow { + override fun echoStream(arg1: Flow): Flow = flow { arg1.collect { emit(it) } @@ -163,55 +155,7 @@ class KrpcTestServiceBackend(override val coroutineContext: CoroutineContext) : return payloadWithStream.payload.length + payloadWithStream.stream.count() } - // necessary for older Kotlin versions - @Suppress("UnnecessaryOptInAnnotation") - @OptIn(FlowPreview::class) - override suspend fun streamInStream(payloadWithStream: Flow): Int { - return payloadWithStream.flatMapConcat { it.stream }.count() - } - - override suspend fun streamOutDataClass(): PayloadWithStream { - return payload() - } - - override suspend fun streamOfStreamsInReturn(): Flow> { - return flow { - emit(flow { emit("a"); emit("b"); emit("c") }) - emit(flow { emit("1"); emit("2"); emit("3") }) - } - } - - override suspend fun streamOfPayloadsInReturn(): Flow { - return payloadStream() - } - - override suspend fun streamInDataClassWithStream(payloadWithPayload: PayloadWithPayload): Int { - assertContentEquals(KrpcTransportTestBase.expectedPayloadWithPayload(10), payloadWithPayload.collect()) - return 5 - } - - override suspend fun streamInStreamWithStream(payloadWithPayload: Flow): Int { - payloadWithPayload.collectIndexed { index, payload -> - assertContentEquals(KrpcTransportTestBase.expectedPayloadWithPayload(index), payload.collect()) - } - return 5 - } - - override suspend fun returnPayloadWithPayload(): PayloadWithPayload { - return payloadWithPayload() - } - - override suspend fun returnFlowPayloadWithPayload(): Flow { - return payloadWithPayloadStream() - } - - override suspend fun bidirectionalFlowOfPayloadWithPayload( - payloadWithPayload: Flow - ): Flow { - return payloadWithPayload - } - - override suspend fun getNInts(n: Int): Flow { + override fun getNInts(n: Int): Flow { return flow { for (it in 1..n) { emit(it) @@ -219,7 +163,7 @@ class KrpcTestServiceBackend(override val coroutineContext: CoroutineContext) : } } - override suspend fun getNIntsBatched(n: Int): Flow> { + override fun getNIntsBatched(n: Int): Flow> { return flow { for (it in (1..n).chunked(1000)) { emit(it) @@ -264,7 +208,7 @@ class KrpcTestServiceBackend(override val coroutineContext: CoroutineContext) : override suspend fun nullableInt(v: Int?): Int? = v override suspend fun nullableList(v: List?): List? = v - override suspend fun delayForever(): Flow = flow { + override fun delayForever(): Flow = flow { emit(true) delay(Int.MAX_VALUE.toLong()) } @@ -287,29 +231,8 @@ class KrpcTestServiceBackend(override val coroutineContext: CoroutineContext) : emit(Unit) } } - - override suspend fun sharedFlowInFunction(sharedFlow: SharedFlow): StateFlow { - val state = MutableStateFlow(-1) - - launch { - assertEquals(listOf(0, 1, 2, 3, 4), sharedFlow.take(5).toList()) - state.emit(1) - } - - return state - } - - override suspend fun stateFlowInFunction(stateFlow: StateFlow): StateFlow { - val state = MutableStateFlow(-1) - assertEquals(-1, stateFlow.value) - - launch { - assertEquals(42, stateFlow.first { it == 42 }) - state.emit(1) - } - - return state - } } internal expect fun runThreadIfPossible(runner: () -> Unit) + +internal expect fun CoroutineScope.debugCoroutines() diff --git a/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTransportTestBase.kt b/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTransportTestBase.kt index 38ef3c36d..a94819879 100644 --- a/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTransportTestBase.kt +++ b/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/KrpcTransportTestBase.kt @@ -12,13 +12,11 @@ import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.* import kotlinx.coroutines.sync.Semaphore import kotlinx.coroutines.test.runTest -import kotlinx.rpc.awaitFieldInitialization import kotlinx.rpc.krpc.KrpcTransport import kotlinx.rpc.krpc.rpcClientConfig import kotlinx.rpc.krpc.rpcServerConfig import kotlinx.rpc.krpc.serialization.KrpcSerialFormatConfiguration import kotlinx.rpc.krpc.server.KrpcServer -import kotlinx.rpc.krpc.streamScoped import kotlinx.rpc.registerService import kotlinx.rpc.withService import kotlinx.serialization.KSerializer @@ -70,10 +68,6 @@ abstract class KrpcTransportTestBase { private val serverConfig by lazy { rpcServerConfig { - sharedFlowParameters { - replay = KrpcTestServiceBackend.SHARED_FLOW_REPLAY - } - serialization { serializationConfig() } @@ -82,10 +76,6 @@ abstract class KrpcTransportTestBase { private val clientConfig by lazy { rpcClientConfig { - sharedFlowParameters { - replay = KrpcTestServiceBackend.SHARED_FLOW_REPLAY - } - serialization { serializationConfig() } @@ -142,7 +132,6 @@ abstract class KrpcTransportTestBase { expected = List(10) { it * 2 }, actual = client.nonSuspendBidirectional(List(10) { it }.asFlow()).toList(), ) - print(1) } @Test @@ -263,175 +252,59 @@ abstract class KrpcTransportTestBase { @Test fun incomingStreamSyncCollect() = runTest { - val result = streamScoped { - client.incomingStreamSyncCollect(flowOf("test1", "test2", "test3")) - } + val result = client.incomingStreamSyncCollect(flowOf("test1", "test2", "test3")) assertEquals(3, result) } @Test - fun incomingStreamAsyncCollect() = runTest { - val result = streamScoped { - client.incomingStreamAsyncCollect(flowOf("test1", "test2", "test3")).also { - server.incomingStreamAsyncCollectLatch.await() - } - } + fun incomingStreamSyncCollectMultiple() = runTest { + val result = client.incomingStreamSyncCollectMultiple( + flowOf("test1", "test2", "test3"), + flowOf("test1", "test2", "test3"), + flowOf("test1", "test2", "test3"), + ) - assertEquals(5, result) + assertEquals(9, result) } @Test fun outgoingStream() = runTest { - streamScoped { - val result = client.outgoingStream() - assertEquals(listOf("a", "b", "c"), result.toList(mutableListOf())) - } + val result = client.outgoingStream() + assertEquals(listOf("a", "b", "c"), result.toList(mutableListOf())) } @Test fun bidirectionalStream() = runTest { - streamScoped { - val result = client.bidirectionalStream(flowOf("test1", "test2", "test3")) - assertEquals( - listOf("test1".reversed(), "test2".reversed(), "test3".reversed()), - result.toList(mutableListOf()), - ) - } + val result = client.bidirectionalStream(flowOf("test1", "test2", "test3")) + assertEquals( + listOf("test1".reversed(), "test2".reversed(), "test3".reversed()), + result.toList(mutableListOf()), + ) } @Test fun streamInDataClass() = runTest { - streamScoped { - val result = client.streamInDataClass(payload()) - assertEquals(8, result) - } + val result = client.streamInDataClass(payload()) + assertEquals(8, result) } @Test - fun streamInStream() = runTest { - streamScoped { - val result = client.streamInStream(payloadStream()) - assertEquals(30, result) - } - } - - @Test - fun streamOutDataClass() = runTest { - streamScoped { - val result = client.streamOutDataClass() - assertEquals("test0", result.payload) - assertEquals(listOf("a0", "b0", "c0"), result.stream.toList(mutableListOf())) - } - } - - @Test - fun streamOfStreamsInReturn() = runTest { - streamScoped { - val result = client.streamOfStreamsInReturn().map { - it.toList(mutableListOf()) - }.toList(mutableListOf()) - assertEquals(listOf(listOf("a", "b", "c"), listOf("1", "2", "3")), result) - } - } - - @Test - fun streamOfPayloadsInReturn() = runTest { - streamScoped { - val result = client.streamOfPayloadsInReturn().map { - it.stream.toList(mutableListOf()).joinToString() - }.toList(mutableListOf()).joinToString() - assertEquals( - "a0, b0, c0, a1, b1, c1, a2, b2, c2, a3, b3, c3, a4, " + - "b4, c4, a5, b5, c5, a6, b6, c6, a7, b7, c7, a8, b8, c8, a9, b9, c9", - result, - ) - } - } - - @Test - fun streamInDataClassWithStream() = runTest { - streamScoped { - val result = client.streamInDataClassWithStream(payloadWithPayload()) - assertEquals(5, result) - } - } - - @Test - fun streamInStreamWithStream() = runTest { - val result = streamScoped { - client.streamInStreamWithStream(payloadWithPayloadStream()) - } - assertEquals(5, result) - } - - @Test - fun returnPayloadWithPayload() = runTest { - streamScoped { - assertContentEquals(expectedPayloadWithPayload(10), client.returnPayloadWithPayload().collect()) - } - } - - @Test - fun returnFlowPayloadWithPayload() = runTest { - streamScoped { - client.returnFlowPayloadWithPayload().collectIndexed { index, payloadWithPayload -> - assertContentEquals(expectedPayloadWithPayload(index), payloadWithPayload.collect()) - } - } - } - - @Test - fun bidirectionalFlowOfPayloadWithPayload() = runTest { - streamScoped { - val result = client.bidirectionalFlowOfPayloadWithPayload( - flow { - repeat(5) { - emit(payloadWithPayload(10)) - } - }, - ) - - val all = result.toList().onEach { - assertContentEquals(expectedPayloadWithPayload(10), it.collect()) - }.size - - assertEquals(5, all) - } - } - - @Test - fun bidirectionalAsyncStream() = runTest { - streamScoped { - val flow = MutableSharedFlow(1) - val result = client.echoStream(flow.take(10)) - launch { - var id = 0 - result.collect { - assertEquals(id, it) - id++ - flow.emit(id) - } - } - - flow.emit(0) - } + fun bidirectionalEchoStream() = runTest { + val result = client.echoStream(flowOf(1, 2, 3)).toList().sum() + assertEquals(6, result) } @Test fun `RPC should be able to receive 100_000 ints in reasonable time`() = runTest { - streamScoped { - val n = 100_000 - assertEquals(client.getNInts(n).last(), n) - } + val n = 100_000 + assertEquals(client.getNInts(n).last(), n) } @Test fun `RPC should be able to receive 100_000 ints with batching in reasonable time`() = runTest { - streamScoped { - val n = 100_000 - assertEquals(client.getNIntsBatched(n).last().last(), n) - } + val n = 100_000 + assertEquals(client.getNIntsBatched(n).last().last(), n) } @Test @@ -447,7 +320,7 @@ abstract class KrpcTransportTestBase { try { client.throwsIllegalArgument("me") fail("Exception expected: throwsIllegalArgument") - } catch (e : AssertionError) { + } catch (e: AssertionError) { throw e } catch (e: Throwable) { assertEquals("me", e.message) @@ -455,7 +328,7 @@ abstract class KrpcTransportTestBase { try { client.throwsSerializableWithMessageAndCause("me") fail("Exception expected: throwsSerializableWithMessageAndCause") - } catch (e : AssertionError) { + } catch (e: AssertionError) { throw e } catch (e: Throwable) { assertEquals("me", e.message) @@ -464,7 +337,7 @@ abstract class KrpcTransportTestBase { try { client.throwsThrowable("me") fail("Exception expected: throwsThrowable") - } catch (e : AssertionError) { + } catch (e: AssertionError) { throw e } catch (e: Throwable) { assertEquals("me", e.message) @@ -472,7 +345,7 @@ abstract class KrpcTransportTestBase { try { client.throwsUNSTOPPABLEThrowable("me") fail("Exception expected: throwsUNSTOPPABLEThrowable") - } catch (e : AssertionError) { + } catch (e: AssertionError) { throw e } catch (e: Throwable) { assertEquals("me", e.message) @@ -498,10 +371,8 @@ abstract class KrpcTransportTestBase { val flag: Channel = Channel() val remote = launch { try { - streamScoped { - client.delayForever().collect { - flag.send(it) - } + client.delayForever().collect { + flag.send(it) } } catch (e: CancellationException) { throw e @@ -573,34 +444,6 @@ abstract class KrpcTransportTestBase { fun testUnitFlow() = runTest { assertEquals(Unit, client.unitFlow().toList().single()) } - - @Test - fun testSharedFlowInFunction() = runTest { - streamScoped { - val flow = sharedFlowOfT { it } - - val state = client.sharedFlowInFunction(flow) - - assertEquals(1, state.first { it == 1 }) - } - } - - @Test - fun testStateFlowInFunction() = runTest { - streamScoped { - val flow = stateFlowOfT { it } - - val state = client.stateFlowInFunction(flow) - - flow.emit(42) - - assertEquals(1, state.first { it == 1 }) - } - } - - companion object { - fun expectedPayloadWithPayload(size: Int) = List(size) { listOf("a$it", "b$it", "c$it") } - } } internal expect val isJs: Boolean diff --git a/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/Payloads.kt b/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/Payloads.kt index 2523a6f26..7a077316b 100644 --- a/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/Payloads.kt +++ b/krpc/krpc-test/src/commonMain/kotlin/kotlinx/rpc/krpc/test/Payloads.kt @@ -4,29 +4,14 @@ package kotlinx.rpc.krpc.test -import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.MutableSharedFlow -import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.map -import kotlinx.coroutines.flow.toList -import kotlinx.coroutines.launch import kotlinx.serialization.Contextual import kotlinx.serialization.Serializable @Serializable data class PayloadWithStream(val payload: String, val stream: @Contextual Flow) -@Serializable -data class PayloadWithPayload(val payload: String, val flow: @Contextual Flow) { - suspend fun collect(): List> { - return flow.map { - it.stream.toList() - }.toList() - } -} - fun payload(index: Int = 0): PayloadWithStream { return PayloadWithStream( "test$index", @@ -34,48 +19,6 @@ fun payload(index: Int = 0): PayloadWithStream { ) } -fun payloadWithPayload(index: Int = 10): PayloadWithPayload { - return PayloadWithPayload("test$index", payloadStream(index)) -} - -fun payloadStream(count: Int = 10): Flow { - return flow { - repeat(count) { - emit(payload(it)) - } - } -} - -fun payloadWithPayloadStream(count: Int = 10): Flow { - return flow { - repeat(count) { - emit(payloadWithPayload(it)) - } - } -} - fun plainFlow(count: Int = 5, get: (Int) -> T): Flow { return flow { repeat(count) { emit(get(it)) } } } - -private fun > CoroutineScope.runSharedFlow( - flow: FlowT, - count: Int = KrpcTestServiceBackend.SHARED_FLOW_REPLAY, - getter: (Int) -> T, -) = apply { - launch { - repeat(count) { - flow.emit(getter(it)) - } - } -} - -fun CoroutineScope.sharedFlowOfT(getter: (Int) -> T): MutableSharedFlow { - return MutableSharedFlow(KrpcTestServiceBackend.SHARED_FLOW_REPLAY).also { flow -> - runSharedFlow(flow) { getter(it) } - } -} - -fun stateFlowOfT(getter: (Int) -> T): MutableStateFlow { - return MutableStateFlow(getter(-1)) -} diff --git a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/SamplingService.kt b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/SamplingService.kt index 738e5d371..36e9f9ff8 100644 --- a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/SamplingService.kt +++ b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/SamplingService.kt @@ -6,14 +6,10 @@ @file:Suppress("PackageDirectoryMismatch") package org.jetbrains.krpc.test.api.util -import kotlinx.coroutines.delay import kotlinx.coroutines.flow.* -import kotlinx.coroutines.launch import kotlinx.rpc.RemoteService import kotlinx.rpc.annotations.Rpc import kotlinx.rpc.krpc.test.plainFlow -import kotlinx.rpc.krpc.test.sharedFlowOfT -import kotlinx.rpc.krpc.test.stateFlowOfT import kotlinx.serialization.Serializable import kotlin.coroutines.CoroutineContext @@ -28,11 +24,7 @@ interface SamplingService : RemoteService { suspend fun clientStream(flow: Flow): List - suspend fun clientNestedStream(flow: Flow>): List> - - suspend fun serverFlow(): Flow - - suspend fun serverNestedFlow(): Flow> + fun serverFlow(): Flow suspend fun callException() } @@ -46,18 +38,10 @@ class SamplingServiceImpl(override val coroutineContext: CoroutineContext) : Sam return flow.toList() } - override suspend fun clientNestedStream(flow: Flow>): List> { - return flow.map { it.toList() }.toList() - } - - override suspend fun serverFlow(): Flow { + override fun serverFlow(): Flow { return plainFlow { SamplingData("data") } } - override suspend fun serverNestedFlow(): Flow> { - return plainFlow { plainFlow { it } } - } - override suspend fun callException() { error("Server exception") } diff --git a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/TransportTest.kt b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/TransportTest.kt index c85f45d91..e1c2e17f3 100644 --- a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/TransportTest.kt +++ b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/TransportTest.kt @@ -159,25 +159,27 @@ class TransportTest { @Test fun testLateConnectWithManyServices() = runTest { - val transports = LocalTransport() + repeat(100) { + val transports = LocalTransport() - val client = clientOf(transports) + val client = clientOf(transports) - val result = List(10) { - async { - val service = client.withService() - service.echo("foo") + val result = List(10) { + async { + val service = client.withService() + service.echo("foo") + } } - } - val server = serverOf(transports) - val echoServices = server.registerServiceAndReturn { EchoImpl(it) } + val server = serverOf(transports) + val echoServices = server.registerServiceAndReturn { EchoImpl(it) } - val response = result.awaitAll() - assertTrue { response.all { it == "foo" } } - assertEquals(10, echoServices.sumOf { it.received.value }) + val response = result.awaitAll() + assertTrue { response.all { it == "foo" } } + assertEquals(10, echoServices.sumOf { it.received.value }) - server.cancel() + server.cancel() + } } diff --git a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationService.kt b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationService.kt index 2dd950b06..a44d3f160 100644 --- a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationService.kt +++ b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationService.kt @@ -9,10 +9,7 @@ import kotlinx.coroutines.* import kotlinx.coroutines.flow.* import kotlinx.rpc.RemoteService import kotlinx.rpc.annotations.Rpc -import kotlinx.rpc.krpc.invokeOnStreamScopeCompletion import kotlin.coroutines.CoroutineContext -import kotlin.properties.Delegates -import kotlin.test.assertIs @Rpc interface CancellationService : RemoteService { @@ -22,7 +19,7 @@ interface CancellationService : RemoteService { suspend fun callException() - suspend fun incomingStream(): Flow + fun incomingStream(): Flow suspend fun outgoingStream(stream: Flow) @@ -30,14 +27,6 @@ interface CancellationService : RemoteService { suspend fun outgoingStreamWithException(stream: Flow) - suspend fun outgoingHotFlow(stream: StateFlow) - - suspend fun incomingHotFlow(): StateFlow - - suspend fun closedStreamScopeCallback() - - suspend fun closedStreamScopeCallbackWithStream(): Flow - fun nonSuspendable(): Flow } @@ -62,7 +51,7 @@ class CancellationServiceImpl(override val coroutineContext: CoroutineContext) : error("callException") } - override suspend fun incomingStream(): Flow { + override fun incomingStream(): Flow { return resumableFlow(fence) } @@ -85,66 +74,6 @@ class CancellationServiceImpl(override val coroutineContext: CoroutineContext) : error("exception in request") } - val hotFlowMirror = MutableStateFlow(-1) - val hotFlowConsumedSize = CompletableDeferred() - - override suspend fun outgoingHotFlow(stream: StateFlow) { - launch { - var cnt = 0 - - val cancellation = runCatching { - stream.collect { - cnt++ - hotFlowMirror.emit(it) - } - } - - val result = runCatching { - assertIs(cancellation.exceptionOrNull(), "Cancellation should be thrown") - - cnt - } - - hotFlowConsumedSize.completeWith(result) - } - } - - var incomingHotFlowJob by Delegates.notNull() - - override suspend fun incomingHotFlow(): StateFlow { - val state = MutableStateFlow(-1) - - incomingHotFlowJob = launch { - repeat(Int.MAX_VALUE) { value -> - state.value = value - - hotFlowMirror.first { it == value } - } - } - - invokeOnStreamScopeCompletion { - incomingHotFlowJob.cancel() - } - - return state - } - - val streamScopeCallbackResult = CompletableDeferred() - - override suspend fun closedStreamScopeCallback() { - invokeOnStreamScopeCompletion { cause -> - streamScopeCallbackResult.complete(cause) - } - } - - override suspend fun closedStreamScopeCallbackWithStream(): Flow { - invokeOnStreamScopeCompletion { cause -> - streamScopeCallbackResult.complete(cause) - } - - return resumableFlow(fence) - } - private fun consume(stream: Flow) { launch { try { diff --git a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationTest.kt b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationTest.kt index 00f82b701..53372c73d 100644 --- a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationTest.kt +++ b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationTest.kt @@ -5,15 +5,8 @@ package kotlinx.rpc.krpc.test.cancellation import kotlinx.coroutines.* -import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.toList -import kotlinx.coroutines.test.runTest -import kotlinx.rpc.krpc.StreamScope -import kotlinx.rpc.krpc.internal.STREAM_SCOPES_ENABLED -import kotlinx.rpc.krpc.invokeOnStreamScopeCompletion -import kotlinx.rpc.krpc.streamScoped -import kotlinx.rpc.krpc.withStreamScope import kotlinx.rpc.withService import kotlin.test.* @@ -191,10 +184,8 @@ class CancellationTest { fun testStreamScopeOutgoing() = runCancellationTest { val fence = CompletableDeferred() - streamScoped { - service.outgoingStream(resumableFlow(fence)) - serverInstance().firstIncomingConsumed.await() - } + service.outgoingStream(resumableFlow(fence)) + serverInstance().firstIncomingConsumed.await() fence.complete(Unit) serverInstance().consumedAll.await() @@ -203,40 +194,10 @@ class CancellationTest { stopAllAndJoin() } - @Test - fun testStreamScopeAbsentForOutgoingStream() = runCancellationTest { - val fence = CompletableDeferred() - - if (STREAM_SCOPES_ENABLED) { - assertFailsWith { - service.outgoingStream(resumableFlow(fence)) - } - } else { - service.outgoingStream(resumableFlow(fence)) - } - - stopAllAndJoin() - } - - @Test - fun testStreamScopeAbsentForIncomingStream() = runCancellationTest { - if (STREAM_SCOPES_ENABLED) { - assertFailsWith { - service.incomingStream() - } - } else { - service.incomingStream() - } - - stopAllAndJoin() - } - @Test fun testStreamScopeIncoming() = runCancellationTest { val first: Int - val flow = streamScoped { - service.incomingStream().apply { first = first() } - } + val flow = service.incomingStream().apply { first = first() } serverInstance().fence.complete(Unit) val consumed = flow.toList() @@ -252,11 +213,9 @@ class CancellationTest { val fence = CompletableDeferred() runCatching { - streamScoped { - service.outgoingStream(resumableFlow(fence)) - serverInstance().firstIncomingConsumed.await() - error("exception in stream scope") - } + service.outgoingStream(resumableFlow(fence)) + serverInstance().firstIncomingConsumed.await() + error("exception in stream scope") } fence.complete(Unit) @@ -271,43 +230,33 @@ class CancellationTest { fun testExceptionInRequest() = runCancellationTest { val fence = CompletableDeferred() - streamScoped { - runCatching { - service.outgoingStreamWithException(resumableFlow(fence)) - } - - // to be sure that exception canceled the stream and not scope closure - serverInstance().consumedAll.await() + runCatching { + service.outgoingStreamWithException(resumableFlow(fence)) } + // to be sure that exception canceled the stream and not scope closure + serverInstance().consumedAll.await() + assertContentEquals(listOf(0), serverInstance().consumedIncomingValues) stopAllAndJoin() } - @Test - fun testNestedStreamScopesForbidden() = runTest { - assertFailsWith { - streamScoped { streamScoped { } } - } - } - @Test fun testExceptionInRequestDoesNotCancelOtherRequests() = runCancellationTest { val fence = CompletableDeferred() - val result = streamScoped { - val flow = service.incomingStream() - - runCatching { - service.outgoingStreamWithException(resumableFlow(fence)) - } - fence.complete(Unit) - serverInstance().fence.complete(Unit) + val flow = service.incomingStream() - flow.toList() + runCatching { + service.outgoingStreamWithException(resumableFlow(fence)) } + fence.complete(Unit) + serverInstance().fence.complete(Unit) + + val result = flow.toList() + serverInstance().consumedAll.await() assertContentEquals(listOf(0), serverInstance().consumedIncomingValues) @@ -319,20 +268,19 @@ class CancellationTest { @Test fun testRequestCancellationCancelsStream() = runCancellationTest { val fence = CompletableDeferred() - streamScoped { - val job = launch { - service.outgoingStreamWithDelayedResponse(resumableFlow(fence)) - } - serverInstance().firstIncomingConsumed.await() + val job = launch { + service.outgoingStreamWithDelayedResponse(resumableFlow(fence)) + } - job.cancel("Test request cancelled") - job.join() - assertTrue("Job must be canceled") { job.isCancelled } + serverInstance().firstIncomingConsumed.await() - // close by request cancel and not scope closure - serverInstance().consumedAll.await() - } + job.cancel("Test request cancelled") + job.join() + assertTrue("Job must be canceled") { job.isCancelled } + + // close by request cancel and not scope closure + serverInstance().consumedAll.await() assertContentEquals(listOf(0), serverInstance().consumedIncomingValues) @@ -342,25 +290,23 @@ class CancellationTest { @Test fun testRequestCancellationCancelsStreamButNotOthers() = runCancellationTest { val fence = CompletableDeferred() - val result = streamScoped { - val job = launch { - service.outgoingStreamWithDelayedResponse(resumableFlow(fence)) - } + val job = launch { + service.outgoingStreamWithDelayedResponse(resumableFlow(fence)) + } - val flow = service.incomingStream() + val flow = service.incomingStream() - serverInstance().firstIncomingConsumed.await() + serverInstance().firstIncomingConsumed.await() - job.cancel("Test request cancelled") - job.join() - assertTrue("Job must be canceled") { job.isCancelled } - serverInstance().fence.complete(Unit) + job.cancel("Test request cancelled") + job.join() + assertTrue("Job must be canceled") { job.isCancelled } + serverInstance().fence.complete(Unit) - // close by request cancel and not scope closure - serverInstance().consumedAll.await() + // close by request cancel and not scope closure + serverInstance().consumedAll.await() - flow.toList() - } + val result = flow.toList() assertContentEquals(listOf(0), serverInstance().consumedIncomingValues) assertContentEquals(List(2) { it }, result) @@ -371,18 +317,16 @@ class CancellationTest { @Test fun testServiceCancellationCancelsStream() = runCancellationTest { val fence = CompletableDeferred() - streamScoped { - launch { - service.outgoingStream(resumableFlow(fence)) - } + launch { + service.outgoingStream(resumableFlow(fence)) + } - serverInstance().firstIncomingConsumed.await() + serverInstance().firstIncomingConsumed.await() - service.cancel("Test request cancelled") - service.join() + service.cancel("Test request cancelled") + service.join() - serverInstance().consumedAll.await() - } + serverInstance().consumedAll.await() assertContentEquals(listOf(0), serverInstance().consumedIncomingValues) @@ -392,24 +336,22 @@ class CancellationTest { @Test fun testServiceCancellationCancelsStreamButNotOthers() = runCancellationTest { val fence = CompletableDeferred() - val secondServiceResult = streamScoped { - launch { - service.outgoingStream(resumableFlow(fence)) - } + launch { + service.outgoingStream(resumableFlow(fence)) + } - serverInstance().firstIncomingConsumed.await() + serverInstance().firstIncomingConsumed.await() - val secondServiceFlow = client - .withService() - .incomingStream() + val secondServiceFlow = client + .withService() + .incomingStream() - service.cancel("Test request cancelled") - service.join() + service.cancel("Test request cancelled") + service.join() - serverInstances[1].fence.complete(Unit) + serverInstances[1].fence.complete(Unit) - secondServiceFlow.toList() - } + val secondServiceResult = secondServiceFlow.toList() serverInstance().consumedAll.await() @@ -422,21 +364,20 @@ class CancellationTest { @Test fun testScopeClosureCancelsAllStreams() = runCancellationTest { val fence = CompletableDeferred() - streamScoped { - service.outgoingStream(resumableFlow(fence)) - client.withService().outgoingStream(resumableFlow(fence)) + service.outgoingStream(resumableFlow(fence)) - serverInstance().firstIncomingConsumed.await() + client.withService().outgoingStream(resumableFlow(fence)) - while (true) { - if (serverInstances.size == 2) { - serverInstances[1].firstIncomingConsumed.await() - break - } + serverInstance().firstIncomingConsumed.await() - unskippableDelay(50) + while (true) { + if (serverInstances.size == 2) { + serverInstances[1].firstIncomingConsumed.await() + break } + + unskippableDelay(50) } serverInstances.forEach { it.consumedAll.await() } @@ -447,153 +388,15 @@ class CancellationTest { stopAllAndJoin() } - @Test - fun testInvokeOnStreamScopeCompletionOnServerWithNoStreams() = runCancellationTest { - streamScoped { - service.closedStreamScopeCallback() - } - - serverInstance().streamScopeCallbackResult.await() - - stopAllAndJoin() - } - - @Test - fun testInvokeOnStreamScopeCompletionOnServer() = runCancellationTest { - val result = streamScoped { - service.closedStreamScopeCallbackWithStream().also { - serverInstance().fence.complete(Unit) - }.toList() - } - - serverInstance().streamScopeCallbackResult.await() - - assertContentEquals(List(2) { it }, result) - - stopAllAndJoin() - } - - @Test - fun testInvokeOnStreamScopeCompletionOnClient() = runCancellationTest { - val streamScopeCompleted = CompletableDeferred() - - streamScoped { - service.closedStreamScopeCallback() - - invokeOnStreamScopeCompletion { - streamScopeCompleted.complete(Unit) - } - } - - streamScopeCompleted.await() - - stopAllAndJoin() - } - - @Test - fun testOutgoingHotFlow() = runCancellationTest { - streamScoped { - val state = MutableStateFlow(-2) - - service.outgoingHotFlow(state) - - val mirror = serverInstance().hotFlowMirror - mirror.first { it == -2 } // initial value - - repeat(3) { value -> - state.value = value - mirror.first { it == value } - } - } - - assertEquals(4, serverInstance().hotFlowConsumedSize.await()) - - stopAllAndJoin() - } - - @Test - fun testIncomingHotFlow() = runCancellationTest { - val state = streamScoped { - val state = service.incomingHotFlow() - - val mirror = serverInstance().hotFlowMirror - repeat(3) { value -> - state.first { it == value } - mirror.value = value - } - - state.first { it == 3 } - - state - } - - serverInstance().incomingHotFlowJob.join() - assertEquals(3, state.value) - assertEquals(2, serverInstance().hotFlowMirror.value) - - stopAllAndJoin() - } - @Test fun testCancelledClientCancelsFlows() = runCancellationTest { - streamScoped { - val flow = service.incomingStream() - - assertEquals(0, flow.first()) - client.cancel() - val rest = flow.toList() + val flow = service.incomingStream() - assertTrue("Rest must be empty, as flow was closed") { rest.isEmpty() } - } - - stopAllAndJoin() - } - - @Test - fun manualStreamScopeNoCancel() = runCancellationTest { - val myJob = Job() - val streamScope = StreamScope(myJob) - - val unrelatedJob = Job() - - var first: Int = -1 - val deferredFlow = CoroutineScope(unrelatedJob).async { - withStreamScope(streamScope) { - service.incomingStream().apply { first = first() } - } - } - val flow= deferredFlow.await() - - serverInstance().fence.complete(Unit) - val consumed = flow.toList() - - assertEquals(0, first) - assertContentEquals(listOf(1), consumed) - - stopAllAndJoin() - } - - @Test - fun manualStreamScopeWithCancel() = runCancellationTest { - val myJob = Job() - val streamScope = StreamScope(myJob) - - val unrelatedJob = Job() - - var first: Int = -1 - val deferredFlow = CoroutineScope(unrelatedJob).async { - withStreamScope(streamScope) { - service.incomingStream().apply { first = first() } - } - } - val flow= deferredFlow.await() - - streamScope.close() - serverInstance().fence.complete(Unit) - val consumed = flow.toList() + assertEquals(0, flow.first()) + client.cancel() + val rest = flow.toList() - assertEquals(0, first) - assertContentEquals(emptyList(), consumed) + assertTrue("Rest must be empty, as flow was closed") { rest.isEmpty() } stopAllAndJoin() } @@ -678,7 +481,7 @@ class CancellationTest { private fun CancellationToolkit.processFlowAndLeaveUnusedForGC( firstDone: CompletableDeferred, - latch: CompletableDeferred + latch: CompletableDeferred, ): Job { val flow = service.nonSuspendable() val requestJob = launch { diff --git a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationToolkit.kt b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationToolkit.kt index 1614ff10d..626853767 100644 --- a/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationToolkit.kt +++ b/krpc/krpc-test/src/commonTest/kotlin/kotlinx/rpc/krpc/test/cancellation/CancellationToolkit.kt @@ -16,7 +16,6 @@ import kotlinx.rpc.krpc.rpcServerConfig import kotlinx.rpc.krpc.serialization.json.json import kotlinx.rpc.krpc.test.KrpcTestClient import kotlinx.rpc.krpc.test.KrpcTestServer -import kotlinx.rpc.krpc.test.KrpcTestServiceBackend import kotlinx.rpc.krpc.test.LocalTransport import kotlinx.rpc.registerService import kotlinx.rpc.withService @@ -51,10 +50,6 @@ class CancellationToolkit(scope: CoroutineScope) : CoroutineScope by scope { val client by lazy { KrpcTestClient(rpcClientConfig { serializationConfig() - - sharedFlowParameters { - replay = KrpcTestServiceBackend.SHARED_FLOW_REPLAY - } }, transport.client) } diff --git a/krpc/krpc-test/src/jsMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.js.kt b/krpc/krpc-test/src/jsMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.js.kt index 5a01eb69d..79eaba187 100644 --- a/krpc/krpc-test/src/jsMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.js.kt +++ b/krpc/krpc-test/src/jsMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.js.kt @@ -4,6 +4,11 @@ package kotlinx.rpc.krpc.test +import kotlinx.coroutines.CoroutineScope + actual inline fun runThreadIfPossible(runner: () -> Unit) { runner() } + +internal actual fun CoroutineScope.debugCoroutines() { +} diff --git a/krpc/krpc-test/src/jvmMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.jvm.kt b/krpc/krpc-test/src/jvmMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.jvm.kt index 9142c37fa..f385e52f0 100644 --- a/krpc/krpc-test/src/jvmMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.jvm.kt +++ b/krpc/krpc-test/src/jvmMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.jvm.kt @@ -4,6 +4,26 @@ package kotlinx.rpc.krpc.test +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.debug.DebugProbes +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import kotlin.time.Duration.Companion.seconds + actual fun runThreadIfPossible(runner: () -> Unit) { Thread(runner).start() } + +@OptIn(ExperimentalCoroutinesApi::class) +internal actual fun CoroutineScope.debugCoroutines() { + DebugProbes.install() + launch { + withContext(Dispatchers.IO) { + delay(10.seconds) + } + DebugProbes.dumpCoroutines() + } +} diff --git a/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/api/ApiVersioningTest.kt b/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/api/ApiVersioningTest.kt index 3b9742d41..fd16c9ecf 100644 --- a/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/api/ApiVersioningTest.kt +++ b/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/api/ApiVersioningTest.kt @@ -4,7 +4,6 @@ package kotlinx.rpc.krpc.test.api -import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import kotlinx.rpc.krpc.internal.CancellationType import kotlinx.rpc.krpc.internal.KrpcMessage @@ -65,17 +64,6 @@ class ApiVersioningTest { } } - @Test - @Ignore("Nested flows proved to be unstable") - fun testClientNestedStreamSampling() = wireSamplingTest("clientNestedStream") { - sample { - val response = clientNestedStream(plainFlow { plainFlow { it } }).join() - val expected = List(5) { List(5) { it } }.join() - - assertEquals(expected, response) - } - } - @Test @Ignore("Flow sampling tests are too unstable. Ignored until better fix") fun testServerStreamSampling() = wireSamplingTest("serverStream") { @@ -87,17 +75,6 @@ class ApiVersioningTest { } } - @Test - @Ignore("Nested flows proved to be unstable") - fun testServerNestedStreamSampling() = wireSamplingTest("serverNestedStream") { - sample { - val response = serverNestedFlow().map { it.toList() }.toList().join() - val expected = List(5) { List(5) { it } }.join() - - assertEquals(expected, response) - } - } - @Test fun testCallExceptionSampling() = wireSamplingTest("callException") { // ignore protobuf here, as it's hard to properly sample stacktrace diff --git a/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/api/WireSamplingTestScope.kt b/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/api/WireSamplingTestScope.kt index 60ab4e205..2ccb266b3 100644 --- a/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/api/WireSamplingTestScope.kt +++ b/krpc/krpc-test/src/jvmTest/kotlin/kotlinx/rpc/krpc/test/api/WireSamplingTestScope.kt @@ -24,7 +24,6 @@ import kotlinx.rpc.krpc.serialization.json.json import kotlinx.rpc.krpc.serialization.protobuf.protobuf import kotlinx.rpc.krpc.test.KrpcTestClient import kotlinx.rpc.krpc.test.KrpcTestServer -import kotlinx.rpc.krpc.test.KrpcTestServiceBackend import kotlinx.rpc.krpc.test.LocalTransport import kotlinx.rpc.krpc.test.api.ApiVersioningTest.Companion.latestVersionOrCurrent import kotlinx.rpc.krpc.test.api.util.GoldComparable @@ -201,10 +200,6 @@ private class WireToolkit(scope: CoroutineScope, format: SamplingFormat, val log serialization { format.init(this) } - - sharedFlowParameters { - replay = KrpcTestServiceBackend.SHARED_FLOW_REPLAY - } }, transport.client) } diff --git a/krpc/krpc-test/src/nativeMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.native.kt b/krpc/krpc-test/src/nativeMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.native.kt index 72c450a8e..64bbe3ee4 100644 --- a/krpc/krpc-test/src/nativeMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.native.kt +++ b/krpc/krpc-test/src/nativeMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.native.kt @@ -4,6 +4,7 @@ package kotlinx.rpc.krpc.test +import kotlinx.coroutines.CoroutineScope import kotlin.native.concurrent.ObsoleteWorkersApi import kotlin.native.concurrent.Worker @@ -11,3 +12,6 @@ import kotlin.native.concurrent.Worker actual fun runThreadIfPossible(runner: () -> Unit) { Worker.start(errorReporting = true).executeAfter(0L, runner) } + +internal actual fun CoroutineScope.debugCoroutines() { +} diff --git a/krpc/krpc-test/src/wasmJsMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.wasmJs.kt b/krpc/krpc-test/src/wasmJsMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.wasmJs.kt index 5a01eb69d..79eaba187 100644 --- a/krpc/krpc-test/src/wasmJsMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.wasmJs.kt +++ b/krpc/krpc-test/src/wasmJsMain/kotlin/kotlinx/rpc/krpc/test/KrpcTestServiceBackend.wasmJs.kt @@ -4,6 +4,11 @@ package kotlinx.rpc.krpc.test +import kotlinx.coroutines.CoroutineScope + actual inline fun runThreadIfPossible(runner: () -> Unit) { runner() } + +internal actual fun CoroutineScope.debugCoroutines() { +} diff --git a/tests/compiler-plugin-tests/src/main/kotlin/kotlinx/rpc/codegen/test/TestRpcClient.kt b/tests/compiler-plugin-tests/src/main/kotlin/kotlinx/rpc/codegen/test/TestRpcClient.kt index eeb8d11cb..86bd99bb6 100644 --- a/tests/compiler-plugin-tests/src/main/kotlin/kotlinx/rpc/codegen/test/TestRpcClient.kt +++ b/tests/compiler-plugin-tests/src/main/kotlin/kotlinx/rpc/codegen/test/TestRpcClient.kt @@ -5,8 +5,7 @@ package kotlinx.rpc.codegen.test import kotlinx.coroutines.* -import kotlinx.coroutines.flow.MutableSharedFlow -import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow import kotlinx.rpc.RpcCall import kotlinx.rpc.RpcClient @@ -20,25 +19,8 @@ object TestRpcClient : RpcClient { return "call_42" as T } - @OptIn(DelicateCoroutinesApi::class) - @Suppress("detekt.GlobalCoroutineUsage") - override fun callAsync(serviceScope: CoroutineScope, call: RpcCall): Deferred { - val callable = call.descriptor.getCallable(call.callableName) - ?: error("No callable found for ${call.callableName}") - - val value = when (callable.name) { - "plainFlow" -> flow { emit("registerPlainFlowField_42") } - - "sharedFlow" -> MutableSharedFlow(1).also { - GlobalScope.launch { it.emit("registerSharedFlowField_42") } - } - - "stateFlow" -> MutableStateFlow("registerStateFlowField_42") - - else -> error("Unknown callable name: ${call.callableName}") - } - - return CompletableDeferred(value as T) + override fun callServerStreaming(call: RpcCall): Flow { + return flow { emit("stream_42" as T) } } override fun provideStubContext(serviceId: Long): CoroutineContext { diff --git a/utils/src/commonMain/kotlin/kotlinx/rpc/internal/utils/map/RpcInternalConcurrentHashMap.kt b/utils/src/commonMain/kotlin/kotlinx/rpc/internal/utils/map/RpcInternalConcurrentHashMap.kt index 6fb7d6672..1818a1709 100644 --- a/utils/src/commonMain/kotlin/kotlinx/rpc/internal/utils/map/RpcInternalConcurrentHashMap.kt +++ b/utils/src/commonMain/kotlin/kotlinx/rpc/internal/utils/map/RpcInternalConcurrentHashMap.kt @@ -14,6 +14,8 @@ public interface RpcInternalConcurrentHashMap { put(key, value) } + public fun merge(key: K, value: V, remappingFunction: (V, V) -> V): V? + public fun computeIfAbsent(key: K, computeValue: () -> V): V public operator fun get(key: K): V? diff --git a/utils/src/commonMain/kotlin/kotlinx/rpc/internal/utils/map/SynchronizedHashMap.kt b/utils/src/commonMain/kotlin/kotlinx/rpc/internal/utils/map/SynchronizedHashMap.kt index ae588405c..502ec6e34 100644 --- a/utils/src/commonMain/kotlin/kotlinx/rpc/internal/utils/map/SynchronizedHashMap.kt +++ b/utils/src/commonMain/kotlin/kotlinx/rpc/internal/utils/map/SynchronizedHashMap.kt @@ -14,6 +14,18 @@ internal class SynchronizedHashMap : RpcInternalConcurrentHashM map.put(key, value) } + override fun merge(key: K, value: V, remappingFunction: (V, V) -> V): V? = synchronized(this) { + val old = map[key] + if (old == null) { + map[key] = value + value + } else { + val new = remappingFunction(old, value) + map[key] = new + new + } + } + override fun computeIfAbsent(key: K, computeValue: () -> V): V = synchronized(this) { map[key] ?: computeValue().also { map[key] = it } } diff --git a/utils/src/jvmMain/kotlin/kotlinx/rpc/internal/utils/map/ConcurrentHashMap.jvm.kt b/utils/src/jvmMain/kotlin/kotlinx/rpc/internal/utils/map/ConcurrentHashMap.jvm.kt index 5aa01899f..6a0e5c507 100644 --- a/utils/src/jvmMain/kotlin/kotlinx/rpc/internal/utils/map/ConcurrentHashMap.jvm.kt +++ b/utils/src/jvmMain/kotlin/kotlinx/rpc/internal/utils/map/ConcurrentHashMap.jvm.kt @@ -20,6 +20,10 @@ private class ConcurrentHashMapJvm(initialSize: Int) : RpcInter return map.put(key, value) } + override fun merge(key: K, value: V, remappingFunction: (V, V) -> V): V? { + return map.merge(key, value) { old, new -> remappingFunction(old, new) } + } + override fun computeIfAbsent(key: K, computeValue: () -> V): V { return map.computeIfAbsent(key) { computeValue() } }