Dynamic Proxies in Kotlin with Coroutines

One of the features that make Kotlin stand out is its support for coroutines. Coroutines allow developers to write asynchronous code in a sequential style, which makes it easier to reason about and less error-prone. Over the last few years, coroutines have proven to be an invaluable tool in a Kotlin developer's toolbelt, especially when writing Android applications.

In this blog post, we will explore how to build dynamic proxies that support suspend functions. A dynamic proxy is an object, created at runtime, that implements one or more interfaces and routes every method call through a single handler. This is useful when it's not possible or desirable to write a concrete implementation of an interface by hand, or when you want to apply the same behavior to every method of an interface.

To create a dynamic proxy in Kotlin, we will use the Proxy.newProxyInstance method from the java.lang.reflect package. This method takes three arguments: the class loader, an array of interfaces to implement, and an invocation handler.

Class loaders are responsible for loading Java classes into the JVM (Java Virtual Machine) at runtime. We can grab the ClassLoader from the service interface itself, through its Class object.

The array of interfaces is an array of Class objects, each of which must represent an interface. The dynamic proxy implements every method those interfaces declare.

The invocation handler is an object that implements the InvocationHandler interface. InvocationHandler defines a single method, invoke, which is called whenever a method is called on the dynamic proxy. invoke takes three parameters: the proxy itself, the Method being invoked, and the array of arguments passed to that method (null when the method takes no arguments).
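
To make those three arguments concrete before coroutines enter the picture, here is a minimal sketch of a plain logging proxy. The Greeter interface and LoggingHandler class are hypothetical names made up for this illustration; only Proxy.newProxyInstance and InvocationHandler come from java.lang.reflect.

```kotlin
import java.lang.reflect.InvocationHandler
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method
import java.lang.reflect.Proxy

// Hypothetical interface used only for this illustration.
interface Greeter {
    fun greet(name: String): String
}

// Hypothetical handler: records every method name it sees, then delegates to the real implementation.
class LoggingHandler(private val target: Any) : InvocationHandler {
    val calls = mutableListOf<String>()

    override fun invoke(proxy: Any, method: Method, args: Array<out Any?>?): Any? {
        calls += method.name
        // args is null when the method takes no parameters
        val actualArgs: Array<out Any?> = args ?: emptyArray<Any?>()
        return try {
            method.invoke(target, *actualArgs)
        } catch (e: InvocationTargetException) {
            // unwrap so callers see the real exception instead of the reflection wrapper
            throw e.cause ?: e
        }
    }
}

fun main() {
    val real = object : Greeter {
        override fun greet(name: String) = "Hello, $name"
    }
    val handler = LoggingHandler(real)
    val greeter = Proxy.newProxyInstance(
        Greeter::class.java.classLoader, // class loader taken from the interface itself
        arrayOf(Greeter::class.java),    // interfaces the proxy implements
        handler,                         // receives every call made on the proxy
    ) as Greeter

    println(greeter.greet("Kotlin")) // Hello, Kotlin
    println(handler.calls)           // [greet]
}
```

Because every call funnels through one invoke method, the same handler works for any interface, which is what makes the cross-cutting use cases below possible.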

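To add support for suspend functions to the dynamic proxy, we can create a custom InvocationHandler that inspects the last element of the args array and checks whether it is a Continuation<T>. On the JVM, the Kotlin compiler translates every suspend function into a method with an extra Continuation parameter at the end, so the proxy receives the caller's continuation as the final argument. Continuation is an interface from the Kotlin standard library (package kotlin.coroutines) that "represents a continuation after a suspension point that returns a value of type T."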
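
A quick way to see this compiler translation is to look at a suspend method through Java reflection, which is exactly the view an InvocationHandler gets. The Calculator interface below is a hypothetical example, not part of the original code.

```kotlin
import kotlin.coroutines.Continuation

// Hypothetical interface used only for this illustration.
interface Calculator {
    suspend fun add(a: Int, b: Int): Int
}

fun main() {
    val add = Calculator::class.java.methods.single { it.name == "add" }
    // The compiler appends a Continuation parameter and erases the return type to Object,
    // because the method may return either the Int result or the COROUTINE_SUSPENDED marker.
    println(add.parameterTypes.map { it.simpleName }) // [int, int, Continuation]
    println(add.returnType.simpleName)                // Object
}
```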

The use cases for this setup are numerous, but here are a few that I've personally used it for:

  1. Detecting authentication errors across all API services and automatically re-authorizing the request, while the original call stays suspended until re-authorization finishes

  2. Logging network statistics for all API services across your application

  3. Wrapping repository-level errors in custom error types (think AuthenticationError, ServerError, DatabaseError)

With all of that said, we can finally get into some code. First, a few classes and helpers we will use alongside our InvocationHandler:

import java.lang.reflect.InvocationHandler
import java.lang.reflect.Proxy

class CustomException(override val message: String) : RuntimeException(message)

/**
 * Creates an instance of [T] that utilizes our custom [InvocationHandler]
 */
inline fun <reified T : Any> createProxy(
    invocationHandler: InvocationHandler,
): T {
    val service = T::class.java

    return Proxy.newProxyInstance(
        service.classLoader,
        arrayOf(service),
        invocationHandler,
    ) as T
}
/**
 * Factory interface to create [InvocationHandler] around the original interface / service
 */
interface InterfaceInvocationHandlerFactory {
    fun create(
        originalInterface: Any,
    ): InvocationHandler
}

The code above is mostly for convenience in creating and utilizing instances of our interfaces throughout our codebase.

Now we can look at the actual implementation of InterfaceInvocationHandlerFactory:

import java.lang.reflect.InvocationHandler
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
import kotlin.coroutines.intrinsics.intercepted

class InterfaceInvocationHandlerFactoryImpl : InterfaceInvocationHandlerFactory {

    override fun create(
        originalInterface: Any,
    ): InvocationHandler {
        return object : InvocationHandler {

            override fun invoke(
                proxy: Any,
                method: Method,
                args: Array<out Any?>?,
            ): Any? {
                // args is null when the proxied method takes no parameters
                val nonNullArgs = args ?: emptyArray<Any?>()
                val continuation = nonNullArgs.continuation()

                return if (continuation == null) {
                    // non-suspending function, just invoke it on the original implementation
                    try {
                        val result = method.invoke(originalInterface, *nonNullArgs)
                        // we could inspect anything that we want on the result at this point
                        result
                    } catch (invocationTargetException: InvocationTargetException) {
                        // unwrap so callers see the original exception, not the reflection wrapper
                        throw invocationTargetException.cause ?: invocationTargetException
                    }
                } else {
                    // create a wrapper around the original continuation so we can capture the result and
                    // potentially inspect it
                    val wrappedContinuation = object : Continuation<Any?> {
                        override val context: CoroutineContext get() = continuation.context

                        override fun resumeWith(
                            result: Result<Any?>,
                        ) {
                            // here is where we could inspect result for any type of result / error that we'd like.
                            // We resume through intercepted() so the caller continues on its own dispatcher
                            // (e.g. the main thread) instead of on whatever thread the proxied function finished on.
                            continuation.intercepted().resumeWith(result)
                        }
                    }

                    invokeSuspendFunction(continuation) {
                        // we want to invoke the method with our continuation wrapper instead of the original
                        // continuation so we can inspect the results. So we drop the last argument (the
                        // original continuation) and append our wrapper in its place
                        val newArgs = nonNullArgs.dropLast(1) + wrappedContinuation

                        try {
                            val result =
                                method.invoke(
                                    originalInterface,
                                    *newArgs.toTypedArray(),
                                )

                            if (result == COROUTINE_SUSPENDED) {
                                // the proxied function actually suspended. Returning the marker tells the caller to
                                // wait until our wrapper's resumeWith is called with the real result
                                result
                            } else {
                                // the function completed without suspending; here is where we could inspect result
                                result
                            }
                        } catch (invocationTargetException: InvocationTargetException) {
                            throw invocationTargetException.cause ?: invocationTargetException
                        }
                    }
                }
            }

            // Runs a suspend block directly with the given continuation (no new coroutine is started)
            @Suppress("UNCHECKED_CAST")
            fun <T> invokeSuspendFunction(
                continuation: Continuation<*>,
                block: suspend () -> T,
            ): T =
                (block as (Continuation<*>) -> T)(continuation)

            // Returns the last argument if it is a Continuation, i.e. if the proxied method is a suspend function
            @Suppress("UNCHECKED_CAST")
            private fun Array<*>?.continuation(): Continuation<Any?>? {
                return this?.lastOrNull() as? Continuation<Any?>
            }
        }
    }
}

I've tried to document the code thoroughly, but here are the high-level ideas behind our proxy:

  1. Check whether we are invoking a suspending or non-suspending method by inspecting the last element of the args array. If it is a Continuation, we can infer that the method is a suspend function.

  2. If it is not a suspending function, we invoke the method on the original implementation directly. At this point we could inspect the result and handle it however we want.

  3. If it is a suspending function, we replace the caller's continuation with a wrapper before invoking the method, so we can inspect the result of the suspend function whether it returns immediately or resumes later.
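
The continuation-wrapping idea from step 3 can be tried in isolation, without a proxy. The sketch below assumes the third use case (wrapping errors in custom types): the DatabaseError class, the mapFailure helper, and loadUser are hypothetical names, and the coroutine is started with the standard library's startCoroutine on an empty context so everything runs synchronously.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.startCoroutine

// Hypothetical error type standing in for the DatabaseError idea.
class DatabaseError(cause: Throwable) : RuntimeException("Database failure", cause)

// Hypothetical helper: wraps a continuation so failures are translated before the caller sees them.
fun <T> Continuation<T>.mapFailure(transform: (Throwable) -> Throwable): Continuation<T> {
    val original = this
    return object : Continuation<T> {
        override val context: CoroutineContext get() = original.context
        override fun resumeWith(result: Result<T>) {
            val error = result.exceptionOrNull()
            // successes pass through untouched; failures are replaced by the mapped error
            original.resumeWith(if (error == null) result else Result.failure<T>(transform(error)))
        }
    }
}

// Hypothetical suspending operation that always fails.
suspend fun loadUser(): String = throw IllegalStateException("disk full")

fun main() {
    var outcome: Result<String>? = null
    // Completion continuation that simply records the final result.
    val completion = Continuation<String>(EmptyCoroutineContext) { outcome = it }
    val block: suspend () -> String = { loadUser() }
    block.startCoroutine(completion.mapFailure { DatabaseError(it) })

    println(outcome?.exceptionOrNull())        // the DatabaseError
    println(outcome?.exceptionOrNull()?.cause) // the original IllegalStateException
}
```

In the proxy, the same mapping would go inside wrappedContinuation.resumeWith (and in the catch blocks, for errors thrown before the function suspends).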

Now that we've gone over everything, we can finally use it with some sample code!

Let's first consider a couple of interfaces for our test:

interface ServiceWithoutSuspendingFunction {
    // adds 2 numbers
    fun add(parameterOne: Int, parameterTwo: Int): Int

    // this will always throw an exception
    fun exception()
}

interface ServiceWithSuspendingFunction {
    // adds 2 numbers
    suspend fun add(parameterOne: Int, parameterTwo: Int): Int

    // this will always throw an exception
    suspend fun exception()
}

We can then create proxied instances of those interfaces, each wrapping a concrete implementation:

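import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext

fun main(args: Array<String>) {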

    val interfaceInvocationHandlerFactoryImpl = InterfaceInvocationHandlerFactoryImpl()


    val serviceWithoutSuspendingFunction = createProxy<ServiceWithoutSuspendingFunction>(
        invocationHandler = interfaceInvocationHandlerFactoryImpl.create(
            object : ServiceWithoutSuspendingFunction {
                override fun add(parameterOne: Int, parameterTwo: Int): Int {
                    return parameterOne + parameterTwo
                }

                override fun exception() {
                    throw CustomException("Error in ServiceWithoutSuspendingFunction")
                }
            }
        )
    )

    val serviceWithSuspendingFunction = createProxy<ServiceWithSuspendingFunction>(
        invocationHandler = interfaceInvocationHandlerFactoryImpl.create(
            object : ServiceWithSuspendingFunction {
                override suspend fun add(parameterOne: Int, parameterTwo: Int): Int {
                    return withContext(Dispatchers.IO) {
                        delay(1000)
                        parameterOne + parameterTwo
                    }
                }

                override suspend fun exception() {
                    throw CustomException("Error in ServiceWithSuspendingFunction")
                }
            }
        )
    )
    // ...the runBlocking block from the next snippet goes here, still inside main()
}

Then, still inside main() (just before its closing brace), we call the methods on both proxies:

    runBlocking {
        println("serviceWithoutSuspendingFunction = ${serviceWithoutSuspendingFunction.add(1, 1)}")
        println("serviceWithSuspendingFunction = ${serviceWithSuspendingFunction.add(1, 1)}")

        try {
            serviceWithoutSuspendingFunction.exception()
        } catch (e: Exception) {
            println(e)
        }

        try {
            serviceWithSuspendingFunction.exception()
        } catch (e: Exception) {
            println(e)
        }
    }

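Both add calls print 2 (the suspending one after its one-second delay), and both exception() calls are caught as the original CustomException rather than as a reflection wrapper such as InvocationTargetException, because the handler unwraps the cause before rethrowing it.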
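
One caveat worth checking in your own services: CustomException is unchecked. If a handler rethrows a checked exception (IOException is common in network code) from a method whose compiled signature declares no throws clause, java.lang.reflect.Proxy wraps it in UndeclaredThrowableException. Kotlin emits no throws clause unless the method is annotated with @Throws, and this also applies to suspend methods when the exception is thrown synchronously (exceptions delivered later through resumeWith bypass the proxy). A minimal sketch, where UserApi, DeclaredUserApi, and failingProxy are hypothetical names:

```kotlin
import java.io.IOException
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Proxy
import java.lang.reflect.UndeclaredThrowableException

// Hypothetical service: Kotlin emits no `throws` clause for fetch().
interface UserApi {
    fun fetch(): String
}

// Hypothetical variant that declares the checked exception in its JVM signature.
interface DeclaredUserApi {
    @Throws(IOException::class)
    fun fetch(): String
}

// Hypothetical helper: builds a proxy whose handler always throws a checked IOException.
inline fun <reified T : Any> failingProxy(): T =
    Proxy.newProxyInstance(
        T::class.java.classLoader,
        arrayOf(T::class.java),
        InvocationHandler { _, _, _ -> throw IOException("network down") },
    ) as T

fun main() {
    try {
        failingProxy<UserApi>().fetch()
    } catch (e: UndeclaredThrowableException) {
        println("wrapped: ${e.cause}") // the IOException arrives wrapped
    }
    try {
        failingProxy<DeclaredUserApi>().fetch()
    } catch (e: IOException) {
        println("passed through: $e") // declared, so the proxy rethrows it as-is
    }
}
```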

In conclusion, dynamic proxies with suspend function support can be a powerful tool in Kotlin, allowing developers to create objects that implement interfaces at runtime while still supporting asynchronous computations with coroutines. While dynamic proxies are a somewhat advanced topic, they are useful in many scenarios. Hopefully I've been able to show some scenarios where dynamic proxies are helpful, as well as how to use them.

The full source code is available here.