package net.sergeych.bipack

import kotlinx.datetime.Instant
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.encoding.CompositeDecoder
import kotlinx.serialization.modules.EmptySerializersModule
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.serializer
import net.sergeych.bintools.*

/**
 * Decode BiPack format. Note that it relies on [DataSource] so can throw [DataSource.EndOfData]
 * exception. Specific frames when used can throw [InvalidFrameException] and its derivatives.
 *
 * A new [BipackDecoder] is created by [beginStructure] for every nested class, list or map, so each
 * instance decodes exactly one structure level.
 *
 * @param input source of the encoded bytes (a [CRC32Source] wrapper for `@CrcProtected` structures).
 * @param elementsCount for classes: how many elements [decodeElementIndex] reports before
 *   [CompositeDecoder.DECODE_DONE]; for fixed-size collections: the collection size.
 * @param isCollection `true` for lists and maps; such decoders decode their elements sequentially.
 * @param hasFixedSize `true` when the collection size comes from a [FixedSize] annotation and is
 *   therefore not read from [input].
 */
@Suppress("UNCHECKED_CAST")
class BipackDecoder(
    val input: DataSource, var elementsCount: Int = 0, val isCollection: Boolean = false,
    val hasFixedSize: Boolean = false,
) : AbstractDecoder() {
    /** Index of the next element to be reported by [decodeElementIndex]. */
    private var elementIndex = 0

    // Per-element options, taken from the annotations of the element currently being decoded
    // (set in decodeElementIndex):

    /** The current element is annotated [Unsigned]: numbers are read as `UInt`/`ULong`. */
    private var nextIsUnsigned = false

    /** Size given by the current element's [FixedSize] annotation, or -1 when it has none. */
    private var fixedSize = -1

    /** The current element is annotated [Fixed]: use `readI16`/`readI32`/`readI64` instead of `readNumber`. */
    private var fixedNumber = false

    override val serializersModule: SerializersModule = EmptySerializersModule()

    /** A single byte; any non-zero value means `true`. */
    override fun decodeBoolean(): Boolean = input.readByte().toInt() != 0
    override fun decodeByte(): Byte = input.readByte()
    override fun decodeShort(): Short =
        if (fixedNumber) input.readI16()
        else if (nextIsUnsigned)
            input.readNumber<UInt>().toShort()
        else
            input.readNumber()

    override fun decodeInt(): Int =
        if (fixedNumber) input.readI32()
        else if (nextIsUnsigned) input.readNumber<UInt>().toInt() else input.readNumber()

    override fun decodeLong(): Long =
        if (fixedNumber) input.readI64()
        else if (nextIsUnsigned) input.readNumber<ULong>().toLong() else input.readNumber()

    override fun decodeFloat(): Float = input.readFloat()
    override fun decodeDouble(): Double = input.readDouble()

    /** A char is stored as its code, read as an unsigned number. */
    override fun decodeChar(): Char = Char(input.readNumber<UInt>().toInt())

    /** Reads a byte array stored as an unsigned length followed by that many bytes. */
    fun readBytes(): ByteArray {
        val length = input.readNumber<UInt>()
        return input.readBytes(length.toInt())
    }

    /** Strings are length-prefixed byte arrays, decoded as UTF-8 by [decodeToString]. */
    override fun decodeString(): String = readBytes().decodeToString()

    /** Enums are stored as their ordinal, read as an unsigned number. */
    override fun decodeEnum(enumDescriptor: SerialDescriptor): Int = input.readNumber<UInt>().toInt()

    /**
     * Reports elements `0 until elementsCount` in order, then [CompositeDecoder.DECODE_DONE].
     * Before returning an index, it applies that element's [Unsigned], [FixedSize] and [Fixed]
     * annotations to the per-element options.
     */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
        if (elementIndex >= elementsCount)
            return CompositeDecoder.DECODE_DONE
        // Annotations describe a single element. Reset every per-element option before reading
        // the next element's annotations. Otherwise a @Fixed field would make all following
        // numeric fields fixed too, and a @FixedSize field would leak its size into later
        // nested structures.
        // NOTE(review): the encoder must apply these annotations per element in the same way;
        // confirm that BipackEncoder resets its options likewise.
        nextIsUnsigned = false
        fixedSize = -1
        fixedNumber = false
        for (a in descriptor.getElementAnnotations(elementIndex)) {
            when (a) {
                is Unsigned -> nextIsUnsigned = true
                is FixedSize -> fixedSize = a.size
                is Fixed -> fixedNumber = true
            }
        }
        return elementIndex++
    }

    /**
     * [Instant] is special-cased: it is read with [decodeLong] as epoch milliseconds instead of
     * going through its own serializer. Everything else uses the default strategy.
     */
    override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
        return if (deserializer == Instant.serializer())
            Instant.fromEpochMilliseconds(decodeLong()) as T
        else
            super.decodeSerializableValue(deserializer)
    }

    /** Collections are decoded sequentially: their size is known up front (see [decodeCollectionSize]). */
    override fun decodeSequentially(): Boolean = isCollection

    /**
     * Starts a nested structure and returns a fresh decoder for it.
     *
     * - `@CrcProtected` structures are read through a [CRC32Source], so [endStructure] can verify
     *   the trailing checksum.
     * - `@Extendable` structures store their element count in the stream.
     * - `@Framed` structures start with the CRC32 of their serial name. A mismatch raises
     *   [InvalidFrameHeaderException].
     */
    override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
        val isCollection = descriptor.kind == StructureKind.LIST || descriptor.kind == StructureKind.MAP

        val source = if (descriptor.annotations.any { it is CrcProtected })
            CRC32Source(input)
        else
            input

        // Note: we must read from 'source' explicitly, as it might be a CRC-calculating one,
        // and the header fields below are CRC protected too:
        var count = if (fixedSize >= 0) fixedSize else descriptor.elementsCount
        for (a in descriptor.annotations) {
            if (a is Extendable)
                count = source.readVarUInt().toInt()
            else if (a is Framed) {
                val code = CRC.crc32(descriptor.serialName.encodeToByteArray())
                // if we fail to read CRC, it is IO error, so DataSource.EndOfData will be
                // thrown here, and it is better than invalid frame exception:
                val actual = source.readU32()
                if (code != actual)
                    throw InvalidFrameHeaderException()
            }
        }
        return BipackDecoder(source, count, isCollection, fixedSize >= 0)
    }

    /** The size comes from the [FixedSize] annotation when present; otherwise it is read from the stream. */
    override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
        return if (hasFixedSize)
            elementsCount
        else
            input.readNumber<UInt>().toInt()
    }

    /**
     * For `@CrcProtected` structures, compares the CRC accumulated so far with the 32-bit value
     * that follows the structure. A mismatch raises [InvalidFrameCRCException].
     */
    override fun endStructure(descriptor: SerialDescriptor) {
        if (input is CRC32Source && descriptor.annotations.any { it is CrcProtected }) {
            // Capture the checksum before reading the stored value, which itself passes through the CRC source.
            val actual = input.crc
            val expected = input.readU32()
            if (actual != expected)
                throw InvalidFrameCRCException()
        }
        super.endStructure(descriptor)
    }

    /**
     * Reads the nullability marker as a boolean. Running out of data is treated as "null".
     * NOTE(review): presumably this lets trailing nullable fields missing from shorter data
     * decode as null; it also hides truncation at that point, so confirm it is intended.
     */
    override fun decodeNotNullMark(): Boolean = try {
        decodeBoolean()
    } catch (_: DataSource.EndOfData) {
        false
    }

    @ExperimentalSerializationApi
    override fun decodeNull(): Nothing? = null

    companion object {
        /** Decodes a value from [source] using an explicit [deserializer]. */
        fun <T> decode(source: DataSource, deserializer: DeserializationStrategy<T>): T =
            BipackDecoder(source).decodeSerializableValue(deserializer)

        /** Decodes a value of type [T] from [source] using the serializer resolved for [T]. */
        @Suppress("unused")
        inline fun <reified T> decode(source: DataSource): T = decode(source, serializer())

        /** Decodes a value of type [T] from the BiPack-encoded bytes in [source]. */
        inline fun <reified T> decode(source: ByteArray): T =
            decode(source.toDataSource(), serializer())
    }
}

/**
 * Decodes this BiPack-encoded byte array into a value of type [T].
 * Convenience shorthand for `BipackDecoder.decode<T>(bytes)`.
 */
inline fun <reified T> ByteArray.decodeFromBipack(): T =
    BipackDecoder.decode<T>(this)
