package net.sergeych.kiloparsec

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.serializer
import net.sergeych.crypto2.toDump
import net.sergeych.kiloparsec.Transport.Device
import net.sergeych.mp_logger.*
import net.sergeych.utools.pack
import net.sergeych.utools.unpack

/**
 * Divan channel that operates some block [Device] exporting a given [localInterface]
 * to remote callers. [LocalInterface] allows session managing, transmitting exceptions
 * in a secure and multiplatform way and provides local command execution (typed RPC).
 *
 * @param device the block device used to exchange packed [Block]s with the remote party
 * @param localInterface executes incoming calls and decodes incoming [Block.Error]s
 * @param commandContext context object passed to every locally executed command
 */
class Transport<S>(
    private val device: Device,
    private val localInterface: LocalInterface<S>,
    private val commandContext: S,
) : Loggable by LogTag("TR:$device"), RemoteInterface {


    /**
     * Channel operates using an abstract device, that performs binary block exchange implementing
     * this interface.
     */
    interface Device {

        /**
         * Input blocks. The element type is not nullable, so the device reports disconnection by
         * closing this channel: [Transport.run] then stops on the resulting receive failure.
         * When [close] is called, the channel should be closed as well.
         */
        val input: ReceiveChannel<UByteArray>

        /**
         * Send a binary block to a remote party where it should be received and put into [input]
         * channel. When the device is closed, including by [close], this channel should be closed too.
         */
        val output: SendChannel<UByteArray>

        /**
         * Close input and output and free any resources. The output channel should be flushed if
         * possible. This method must not throw exceptions.
         */
        suspend fun close()
    }

    /**
     * Blocks exchanged over the [Device], serialized with [TransportBlockSerializer].
     * Subclasses holding byte arrays override `equals`/`hashCode` to compare array contents,
     * because the generated data-class versions would compare array references only.
     */
    @Serializable(TransportBlockSerializer::class)
    sealed class Block {
        /** Remote call request; the reply carries [id] back as its `forId`. */
        @Serializable
        data class Call(val id: UInt, val name: String, val packedArgs: UByteArray) : Block() {
            override fun equals(other: Any?): Boolean {
                if (this === other) return true
                if (other !is Call) return false
                if (id != other.id) return false
                if (name != other.name) return false
                if (!(packedArgs contentEquals other.packedArgs)) return false

                return true
            }

            override fun hashCode(): Int {
                var result = id.hashCode()
                result = 31 * result + name.hashCode()
                result = 31 * result + packedArgs.contentHashCode()
                return result
            }
        }

        /** Successful reply to the [Call] with id [forId]. */
        @Serializable
        data class Response(val forId: UInt, val packedResult: UByteArray) : Block() {
            override fun equals(other: Any?): Boolean =
                this === other ||
                        (other is Block.Response && forId == other.forId &&
                                packedResult contentEquals other.packedResult)

            override fun hashCode(): Int = 31 * forId.hashCode() + packedResult.contentHashCode()
        }

        /** Failed reply to the [Call] with id [forId]; decoded by [LocalInterface.decodeError]. */
        @Serializable
        data class Error(val forId: UInt, val code: String, val text: String? = null, val extra: UByteArray? = null) :
            Block() {
            val message by lazy { text ?: "remote exception: $code" }

            override fun equals(other: Any?): Boolean =
                this === other ||
                        (other is Block.Error && forId == other.forId && code == other.code &&
                                text == other.text && extra contentEquals other.extra)

            override fun hashCode(): Int {
                var result = forId.hashCode()
                result = 31 * result + code.hashCode()
                result = 31 * result + text.hashCode()
                result = 31 * result + extra.contentHashCode()
                return result
            }
        }
    }

    // guards lastId and calls
    private val access = Mutex()
    private var lastId = 0u
    // pending outgoing calls by id, completed by incoming Response/Error blocks or on close
    private val calls = mutableMapOf<UInt, CompletableDeferred<UByteArray>>()

    /** Set once [run] stops; calls made after that fail with [RemoteInterface.ClosedException]. */
    var isClosed: Boolean = false

    /**
     * Send a call block for a command and packed args and return the packed result if it is not an error.
     * Normally use [call] instead.
     * @throws RemoteInterface.RemoteException if the remote call caused an exception (as decoded
     *   by [LocalInterface.decodeError])
     * @throws RemoteInterface.ClosedException if the transport is closed or the block could not be sent
     */
    private suspend fun sendCallBlock(name: String, packedArgs: UByteArray): UByteArray {
        if (isClosed) throw RemoteInterface.ClosedException()

        val deferred = CompletableDeferred<UByteArray>()

        // Only the id counter and the pending-calls map need the mutex:
        val b = access.withLock {
            if (isClosed) throw RemoteInterface.ClosedException()
            Block.Call(++lastId, name, packedArgs).also { calls[it.id] = deferred }
        }

        // The mutex is released here, so a slow send does not block other callers or the input loop.
        try {
            send(b)
        } catch (t: Throwable) {
            // The block was not sent, so no reply will ever arrive: drop the pending entry.
            // NonCancellable lets this cleanup finish even if this coroutine is being cancelled.
            withContext(NonCancellable) { access.withLock { calls.remove(b.id) } }
            // A genuine cancellation of the caller must propagate as a cancellation:
            currentCoroutineContext().ensureActive()
            exception { "failed to send output block" to t }
            throw RemoteInterface.ClosedException()
        }

        // it returns packed result or throws a proper error:
        return deferred.await()
    }

    /**
     * Call the remote procedure with specified args and return its result.
     * @throws RemoteInterface.ClosedException if the transport is closed or the call could not be sent
     */
    override suspend fun <A, R> call(cmd: Command<A, R>, args: A): R {
        val result = sendCallBlock(cmd.name, pack(cmd.argsSerializer, args))
        return unpack(cmd.resultSerializer, result)
    }

    /**
     * Start running the transport. This function suspends until the transport is closed
     * normally or by error. If you need to cancel it prematurely, cancel the coroutine
     * it is started in. This approach allows using transport with lifespan connected to the
     * calling coroutine which greatly simplifies its usage in popular async platforms like
     * a ktor client and server, compose multiplatform, etc.
     *
     * When the loop stops, every pending outgoing call fails with [RemoteInterface.ClosedException].
     */
    suspend fun run() {
        coroutineScope {
            debug { "awaiting incoming blocks" }
            while (isActive && !isClosed) {
                try {
                    val packed = device.input.receive()
                    debug { "<<<\n${packed.toDump()}" }
                    val b = unpack<Block>(packed)
                    debug { "<<$ $b" }
                    debug { "access state: ${access.isLocked}" }
                    when (b) {
                        is Block.Error -> access.withLock {
                            // decode once so the logged error is the very instance delivered to the caller
                            val decoded = localInterface.decodeError(b)
                            warning { "decoded error: ${decoded::class.simpleName}: $decoded" }
                            calls.remove(b.forId)?.completeExceptionally(decoded)
                                ?: warning { "error handler not found for ${b.forId}" }
                            info { "error processed" }
                        }

                        is Block.Response -> access.withLock {
                            calls.remove(b.forId)?.let {
                                debug { "activating wait handle for ${b.forId}" }
                                it.complete(b.packedResult)
                            }
                                ?: warning { "wait handle not found for ${b.forId}" }
                        }

                        is Block.Call -> launch { executeCall(b) }
                    }
                } catch (_: CancellationException) {
                    info { "loop is cancelled" }
                    isClosed = true
                } catch (t: Throwable) {
                    exception { "channel closed on error" to t }
                    info { "isa? $isActive / $isClosed" }
                    runCatching { device.close() }
                    isClosed = true
                }
            }
            // Fail every pending call so no caller waits forever. NonCancellable is required because
            // we can get here from the cancellation branch above, and a cancelled coroutine might
            // otherwise fail to acquire the mutex and skip this cleanup.
            withContext(NonCancellable) {
                access.withLock {
                    isClosed = true
                    for (c in calls.values) c.completeExceptionally(RemoteInterface.ClosedException())
                    calls.clear()
                }
            }
            debug { "no more active: $isActive / ${calls.size}" }
        }
        info { "exiting transport loop" }
    }

    /**
     * Execute an incoming [call] with [localInterface] and reply with a [Block.Response] or a
     * [Block.Error] carrying the call id.
     */
    private suspend fun executeCall(call: Block.Call) {
        try {
            send(
                Block.Response(
                    call.id,
                    localInterface.execute(commandContext, call.name, call.packedArgs)
                )
            )
        } catch (x: RemoteInterface.ClosedException) {
            // strange case: handler throws closed?
            error { "not supported: command handler for $call has thrown ClosedException" }
            send(Block.Error(call.id, "UnexpectedException", x.message))
        } catch (x: RemoteInterface.RemoteException) {
            send(Block.Error(call.id, x.code, x.text, x.extra))
        } catch (t: Throwable) {
            // If the transport itself is being cancelled, propagate the cancellation instead of
            // reporting it to the remote. A CancellationException raised by the handler while we
            // are still active (e.g. from its own timeout) is reported so the remote call won't hang.
            currentCoroutineContext().ensureActive()
            send(Block.Error(call.id, "UnknownError", t.message))
        }
        debug { "command executed: ${call.name}" }
    }

    /**
     * Pack [block] and send it to the device output. The parameter is deliberately typed [Block]
     * (not a subclass) so it is packed as a [Block]; presumably `pack` picks the serializer from the
     * static type, which makes it go through [TransportBlockSerializer] — verify before changing.
     */
    private suspend fun send(block: Block) {
        device.output.send(pack(block).also { debug { ">>>\n${it.toDump()}" } })
    }

}

/**
 * Serializer for [Transport.Block]: writes a one-byte type tag (0 = [Transport.Block.Call],
 * 1 = [Transport.Block.Error], 2 = [Transport.Block.Response]) and then the block itself,
 * encoded with that subclass's own serializer. Decoding reads the tag first and fails with
 * [RemoteInterface.InvalidDataException] on an unknown tag.
 */
object TransportBlockSerializer : KSerializer<Transport.Block> {
    private const val TAG_CALL: Byte = 0
    private const val TAG_ERROR: Byte = 1
    private const val TAG_RESPONSE: Byte = 2

    override val descriptor: SerialDescriptor = PrimitiveSerialDescriptor("TransportBlock", PrimitiveKind.INT)

    override fun serialize(encoder: Encoder, value: Transport.Block) {
        when (value) {
            is Transport.Block.Call -> with(encoder) {
                encodeByte(TAG_CALL)
                encodeSerializableValue(serializer<Transport.Block.Call>(), value)
            }

            is Transport.Block.Error -> with(encoder) {
                encodeByte(TAG_ERROR)
                encodeSerializableValue(serializer<Transport.Block.Error>(), value)
            }

            is Transport.Block.Response -> with(encoder) {
                encodeByte(TAG_RESPONSE)
                encodeSerializableValue(serializer<Transport.Block.Response>(), value)
            }
        }
    }

    override fun deserialize(decoder: Decoder): Transport.Block {
        val tag = decoder.decodeByte()
        return when (tag) {
            TAG_CALL -> decoder.decodeSerializableValue(serializer<Transport.Block.Call>())
            TAG_ERROR -> decoder.decodeSerializableValue(serializer<Transport.Block.Error>())
            TAG_RESPONSE -> decoder.decodeSerializableValue(serializer<Transport.Block.Response>())
            else -> throw RemoteInterface.InvalidDataException("wrong block type: $tag")
        }
    }
}