package net.sergeych.bipack

import kotlinx.datetime.Instant
import kotlinx.serialization.SerializationStrategy
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.AbstractEncoder
import kotlinx.serialization.encoding.CompositeEncoder
import kotlinx.serialization.modules.EmptySerializersModule
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.serializer
import net.sergeych.bintools.*

/**
 * kotlinx.serialization encoder that writes values in the Bipack binary format to [output].
 *
 * Per-element annotations ([Unsigned], [FixedSize], [Fixed]) are read in [encodeElement] and affect
 * only the element currently being encoded. Class-level annotations ([CrcProtected], [Framed],
 * [Extendable]) are handled in [beginStructure] / [endStructure].
 *
 * Every nested structure is encoded by a fresh child encoder (see [beginStructure]). Collections are
 * encoded by this same instance (see [beginCollection]).
 */
class BipackEncoder(val output: DataSink) : AbstractEncoder() {

    // Element-scoped flags. They are derived from the annotations of the element being encoded and
    // are reset at the start of every element in encodeElement.

    /** The current element is annotated [Unsigned]: integers are written as their unsigned counterparts. */
    private var nextIsUnsigned = false

    /** Size required by a [FixedSize] annotation on the current element, or -1 when there is none. */
    private var fixedSize: Int = -1

    /** The current element is annotated [Fixed]: Short/Int/Long are written with fixed width. */
    private var fixedNumber: Boolean = false

    /**
     * Called before each element of a structure or collection.
     *
     * Clears all element-scoped flags, then sets them from the annotations of element [index] of
     * [descriptor]. This keeps one field's annotations from affecting the fields that follow it.
     */
    override fun encodeElement(descriptor: SerialDescriptor, index: Int): Boolean =
        super.encodeElement(descriptor, index).also {
            // Every flag must be reset here. Previously only nextIsUnsigned was reset, so a
            // @FixedSize / @Fixed field leaked into all later elements handled by this encoder.
            // NOTE(review): the matching decoder must scope these flags per element the same way
            // to keep round-trips symmetric — verify.
            nextIsUnsigned = false
            fixedSize = -1
            fixedNumber = false
            for (a in descriptor.getElementAnnotations(index)) {
                when (a) {
                    is Unsigned -> nextIsUnsigned = true
                    is FixedSize -> fixedSize = a.size
                    is Fixed -> fixedNumber = true
                }
            }
        }

    override val serializersModule: SerializersModule = EmptySerializersModule()

    /** Writes a single byte: 1 for `true`, 0 for `false`. */
    override fun encodeBoolean(value: Boolean) = output.writeByte(if (value) 1 else 0)

    /** Writes the byte as-is. */
    override fun encodeByte(value: Byte) = output.writeByte(value.toInt())

    /** Writes fixed 16 bits when [Fixed]; otherwise writes through `writeNumber`, as unsigned if [Unsigned]. */
    override fun encodeShort(value: Short) = when {
        fixedNumber -> output.writeI16(value)
        nextIsUnsigned -> output.writeNumber(value.toUShort())
        else -> output.writeNumber(value)
    }

    /** Writes fixed 32 bits when [Fixed]; otherwise writes through `writeNumber`, as unsigned if [Unsigned]. */
    override fun encodeInt(value: Int) = when {
        fixedNumber -> output.writeI32(value)
        nextIsUnsigned -> output.writeNumber(value.toUInt())
        else -> output.writeNumber(value)
    }

    /** Writes an unsigned number through `writeNumber` (used for collection sizes). */
    fun encodeUInt(value: UInt) = output.writeNumber(value)

    /** Writes fixed 64 bits when [Fixed]; otherwise writes through `writeNumber`, as unsigned if [Unsigned]. */
    override fun encodeLong(value: Long) = when {
        fixedNumber -> output.writeI64(value)
        nextIsUnsigned -> output.writeNumber(value.toULong())
        else -> output.writeNumber(value)
    }

    override fun encodeFloat(value: Float) = output.writeFloat(value)
    override fun encodeDouble(value: Double) = output.writeDouble(value)

    /** Writes the character's code as an unsigned number. */
    override fun encodeChar(value: Char) = output.writeNumber(value.code.toUInt())

    /** Writes the string as its UTF-8 bytes, prefixed with their byte count (see [writeBytes]). */
    override fun encodeString(value: String) {
        writeBytes(value.encodeToByteArray())
    }

    /** Writes [value]'s size as an unsigned number, followed by the raw bytes. */
    fun writeBytes(value: ByteArray) {
        output.writeNumber(value.size.toUInt())
        output.writeBytes(value)
    }

    /** Writes the enum constant's declaration index as an unsigned number. */
    override fun encodeEnum(enumDescriptor: SerialDescriptor, index: Int) = output.writeNumber(index.toUInt())

    /**
     * Starts a collection. The elements are encoded by this same encoder.
     *
     * Without a [FixedSize] annotation on the current element, the size is written as an unsigned
     * number. With one, no size is written, and the actual size must match.
     *
     * @throws WrongCollectionSize if [collectionSize] differs from the required fixed size.
     */
    override fun beginCollection(descriptor: SerialDescriptor, collectionSize: Int): CompositeEncoder {
        if (fixedSize < 0)
            encodeUInt(collectionSize.toUInt())
        else if (collectionSize != fixedSize) {
            throw WrongCollectionSize("collection size is $collectionSize while fixed size of $fixedSize is required")
        }
        return this
    }

    /**
     * Encodes [Instant] values as their epoch milliseconds via [encodeLong], so the current element's
     * [Fixed] / [Unsigned] flags apply. Everything else goes through the normal serializer path.
     */
    override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) {
        if (value is Instant) encodeLong(value.toEpochMilliseconds())
        else super.encodeSerializableValue(serializer, value)
    }

    /**
     * Starts a structure and returns a fresh child encoder for its fields.
     *
     * - [CrcProtected]: output is routed through a [CRC32Sink] wrapping [output]. The wrapper is
     *   created before any header is written, so the headers below go through it too.
     * - [Framed]: writes the CRC32 of the descriptor's serial name as a 32-bit value.
     * - [Extendable]: writes the descriptor's element count as a variable-length unsigned int.
     */
    override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder {
        // frame protection should start before anything else:
        val sink = if (descriptor.annotations.any { it is CrcProtected })
            CRC32Sink(output)
        else
            output
        // now it is safe to process anything else using `sink`, not the output!
        for (a in descriptor.annotations) {
            if (a is Framed) {
                sink.writeU32(
                    CRC.crc32(descriptor.serialName.encodeToByteArray())
                )
            } else if (a is Extendable) {
                sink.writeVarUInt(descriptor.elementsCount.toUInt())
            }
        }
        return BipackEncoder(sink)
    }

    /**
     * Ends a structure. For a [CrcProtected] structure, this encoder's output is the [CRC32Sink]
     * created in [beginStructure], and its current `crc` value is appended as a 32-bit value.
     */
    override fun endStructure(descriptor: SerialDescriptor) {
        if (output is CRC32Sink && descriptor.annotations.any { it is CrcProtected }) {
            output.writeU32(output.crc)
        }
        super.endStructure(descriptor)
    }

    /** Nullable values are prefixed by a boolean marker: `false` means null. */
    override fun encodeNull() = encodeBoolean(false)

    /** Marker written before a non-null nullable value. */
    override fun encodeNotNullMark() = encodeBoolean(true)

    companion object {
        /** Encodes [value] with [serializer] into [sink]. */
        fun <T> encode(serializer: SerializationStrategy<T>, value: T, sink: DataSink) {
            val encoder = BipackEncoder(sink)
            encoder.encodeSerializableValue(serializer, value)
        }

        /** Encodes [value] with [serializer] and returns the produced bytes. */
        fun <T> encode(serializer: SerializationStrategy<T>, value: T): ByteArray =
            ArrayDataSink().also { encode(serializer, value, it) }.toByteArray()

        /** Encodes [value] using the serializer for its reified type and returns the produced bytes. */
        inline fun <reified T> encode(value: T) = encode(serializer(), value)

        /** Encodes [value] using the serializer for its reified type into [sink]. */
        @Suppress("unused")
        inline fun <reified T> encode(value: T, sink: DataSink) = encode(serializer(), value, sink)

    }
}
