import { Flasher } from "./flash"
import { wait } from "./utils"

/** Vendor ID of the STMicroelectronics system DFU bootloader entry below. */
const STM_DFU_VENDOR_ID = 0x0483
/** Vendor ID shared by every Ignition / Ergodox EZ entry below. */
const KEYBOARD_VENDOR_ID = 0x3297

/**
 * USB filters offered to `navigator.usb.requestDevice`. Any device whose
 * (vendorId, productId) pair appears here can be claimed, and a `.dfu` file
 * suffix must name one of these pairs to be accepted.
 */
const DEVICE_FILTERS = [
  { productId: 0xdf11, vendorId: STM_DFU_VENDOR_ID },
  { productId: 0x0791, vendorId: KEYBOARD_VENDOR_ID }, // Ignition voyager stm32
  { productId: 0x1791, vendorId: KEYBOARD_VENDOR_ID }, // Ignition voyager gd32
  { productId: 0x2000, vendorId: KEYBOARD_VENDOR_ID }, // Ergodox EZ ST Glow
  { productId: 0x2001, vendorId: KEYBOARD_VENDOR_ID }, // Ergodox EZ ST Shine
  // NOTE(review): same label as 0x2001 above — confirm which board this is.
  { productId: 0x2002, vendorId: KEYBOARD_VENDOR_ID } // Ergodox EZ ST Shine
]

// Product IDs that identify an Ignition bootloader, split by MCU family.
const IGNITION_STM32_PIDS = [0x0791, 0x2000, 0x2001, 0x2002]
const IGNITION_GD32_PIDS = [0x1791]

// Granularity used for both erase steps and write chunks.
const BLOCK_SIZE = 2048

// Flash window erased/written when talking to the STMicro bootloader.
const DFU_START_ADDRESS = 0x08000000
const DFU_END_ADDRESS = 0x08040000
// Size of the erase window; progress reporting counts erase bytes against this.
const DFU_RESET_BYTES = DFU_END_ADDRESS - DFU_START_ADDRESS
// Flash window used with Ignition bootloaders (shifted past the bootloader itself).
const DFU_START_ADDRESS_IGNITION = 0x08002000
const DFU_END_ADDRESS_IGNITION = 0x08042000

// USB configuration / interface claimed for DFU transfers.
const DFU_CONF = 0x01
const DFU_IFACE = 0x00
// Length in bytes of the DFU file suffix appended to firmware images.
const DFU_SUFFIX_LENGTH = 16

// DFU class requests.
const DFU_DNLOAD = 0x01
const DFU_GETSTATUS = 0x03
const DFU_CLRSTATUS = 0x04
const DFU_GETSTATE = 0x05
// DfuSe commands, sent as the first byte of a DFU_DNLOAD to block 0.
const DFU_SET_ADDRESS = 0x21
const DFU_ERASE_FLASH = 0x41

// Device states as reported in the GETSTATUS bState byte.
// NOTE(review): in DFU 1.1 numbering 5 is dfuDNLOAD-IDLE (dfuIDLE is 2);
// this is the state the flasher waits for after each download.
const DFU_IDLE = 0x05
const DFU_ERROR = 0x0a

/** Progress-style callback receiving a byte count. */
type stateCallback = (bytes: number) => void

/**
 * WebUSB flasher for DfuSe-style bootloaders: the STMicro system DFU
 * bootloader and the Ignition bootloader (STM32 and GD32 variants).
 *
 * Typical use: `claim()` to select and open a device, then `flash()` with a
 * firmware image that ends in a 16-byte DFU suffix.
 */
class DFUFlasher implements Flasher {
  _devHandle: USBDevice | null = null
  _progressCallBack: stateCallback
  _totalBytesCallback: stateCallback
  _doubleCompilation: boolean

  // NOTE(review): not read anywhere in this class; flash() derives its own
  // start address from the claimed device. Kept for external callers.
  startAddress = DFU_START_ADDRESS

  /**
   * @param progressCallBack Receives the cumulative number of bytes processed;
   *   erase progress counts up to DFU_RESET_BYTES, then written bytes add on top.
   * @param totalBytesCallback Receives, once per flash, the total that progress
   *   counts towards (firmware size + DFU_RESET_BYTES).
   * @param doubleCompilation When true and an Ignition bootloader is claimed,
   *   the firmware file is expected to hold an STM32 image followed by a GD32 image.
   */
  constructor(
    progressCallBack: stateCallback,
    totalBytesCallback: stateCallback,
    doubleCompilation: boolean
  ) {
    this._progressCallBack = progressCallBack
    this._totalBytesCallback = totalBytesCallback
    this._doubleCompilation = doubleCompilation
  }

  //Private members

  /**
   * Class/interface control transfer from the device.
   * @throws Error if no device is claimed or the transfer does not complete "ok".
   */
  async _transferIN(
    bRequest: number,
    wLength: number,
    wValue = 0
  ): Promise<DataView> {
    if (!this._devHandle) throw new Error("No DFU device claimed.")
    const res = await this._devHandle.controlTransferIn(
      {
        requestType: "class",
        recipient: "interface",
        request: bRequest,
        value: wValue,
        index: DFU_IFACE
      },
      wLength
    )
    if (res.status !== "ok" || !res.data) {
      throw new Error("Control transfer failed.")
    }
    return res.data
  }

  /**
   * Class/interface control transfer to the device.
   * @returns The number of bytes the device accepted.
   * @throws Error if no device is claimed or the transfer does not complete "ok".
   */
  async _transferOUT(
    bRequest: number,
    data?: ArrayBuffer,
    wValue = 0
  ): Promise<number> {
    if (!this._devHandle) throw new Error("No DFU device claimed.")
    const res = await this._devHandle.controlTransferOut(
      {
        requestType: "class",
        recipient: "interface",
        request: bRequest,
        value: wValue,
        index: DFU_IFACE
      },
      data
    )
    if (res.status !== "ok") throw new Error("Control transfer failed.")
    return res.bytesWritten
  }

  /**
   * Issues DFU_GETSTATUS and decodes the 6-byte reply:
   * byte 0 = status, bytes 1-3 = poll timeout (ms, little-endian), byte 4 = state.
   */
  async _getStatus() {
    const data = await this._transferIN(DFU_GETSTATUS, 6)
    return {
      status: data.getUint8(0),
      // Only three bytes belong to the timeout; the fourth read is masked off.
      pollTimeout: data.getUint32(1, true) & 0xffffff,
      state: data.getUint8(4)
    }
  }

  /** Issues DFU_CLRSTATUS (used to leave the error state). */
  async _clearStatus() {
    await this._transferOUT(DFU_CLRSTATUS)
  }

  /** Issues DFU_GETSTATE and returns the single state byte. */
  async _getState() {
    const data = await this._transferIN(DFU_GETSTATE, 1)
    return data.getUint8(0)
  }

  /** Sends `block` with DFU_DNLOAD using `blockNum` as wValue. */
  async _download(block: ArrayBuffer, blockNum: number): Promise<number> {
    return this._transferOUT(DFU_DNLOAD, block, blockNum)
  }

  /**
   * Polls GETSTATUS (sleeping the device-requested poll timeout between polls)
   * until the device reports DFU_IDLE, the device is closed, or 100 polls pass.
   * @throws Error if the device reports DFU_ERROR, so a failed erase/write
   *   stops the flash instead of continuing on a broken device state.
   */
  async _waitIdle() {
    const MAX_TICK = 100
    let status: { status: number; pollTimeout: number; state: number }
    let ticks = 0
    do {
      status = await this._getStatus()
      if (status.state == DFU_ERROR) {
        throw new Error(
          "DFU device reported an error (status 0x" +
            status.status.toString(16) +
            ")."
        )
      }
      await wait(status.pollTimeout)
    } while (
      ticks++ < MAX_TICK &&
      this._devHandle &&
      this._devHandle.opened &&
      status.state != DFU_IDLE
    )
    if (status.state != DFU_IDLE && this._devHandle && this._devHandle.opened) {
      console.warn("Timed out waiting for the DFU device to become idle.")
    }
  }

  /**
   * Sends a DfuSe command through DFU_DNLOAD to block 0: one command byte
   * followed by an optional 1-byte or 4-byte little-endian parameter.
   */
  async _dfu_command(command: number, param = 0x00, len = 0) {
    const buff = new ArrayBuffer(len + 1)
    const view = new DataView(buff)
    view.setUint8(0, command)
    if (len == 1) {
      view.setUint8(1, param)
    } else if (len == 4) {
      view.setUint32(1, param, true)
    }

    await this._download(buff, 0)
  }

  /** Sends an empty DFU_DNLOAD followed by GETSTATUS to make the device leave DFU mode. */
  async _dfuReboot() {
    const buff = new ArrayBuffer(0)
    await this._download(buff, 0)
    await this._getStatus()
  }

  /**
   * Erases [startAddress, endAddress) one BLOCK_SIZE step at a time, reporting
   * cumulative erased bytes through the progress callback.
   */
  async _eraseFlash(startAddress: number, endAddress: number) {
    let addr = startAddress
    let bytesErased = 0
    while (addr < endAddress) {
      await this._dfu_command(DFU_ERASE_FLASH, addr, 4)
      await this._waitIdle()
      addr += BLOCK_SIZE
      bytesErased += BLOCK_SIZE
      this._progressCallBack(bytesErased)
    }
  }

  /** Sets the device's address pointer for the next block-2 download. */
  async _setAddress(address: number) {
    await this._dfu_command(DFU_SET_ADDRESS, address, 4)
    await this._waitIdle()
  }

  /**
   * Reads the USB ids out of a DFU suffix. Bytes 8-10 must spell "UFD";
   * bytes 2-3 hold the product id and 4-5 the vendor id (little-endian).
   * The suffix CRC is not checked.
   * @returns The ids, or null when the signature does not match.
   */
  _parseDFUSuffix(suffix: Uint8Array): USBDeviceFilter | null {
    const d = String.fromCharCode(suffix[10])
    const f = String.fromCharCode(suffix[9])
    const u = String.fromCharCode(suffix[8])

    if (d == "D" && f == "F" && u == "U") {
      const vendorId = (suffix[5] << 8) + suffix[4]
      const productId = (suffix[3] << 8) + suffix[2]
      return { vendorId, productId }
    } else {
      return null
    }
  }

  /**
   * Validates the trailing DFU suffix and returns the firmware without it.
   * @throws Error if the suffix is missing or names ids not in DEVICE_FILTERS.
   */
  _extractSTM32DFUSuffix(firmware: ArrayBuffer): ArrayBuffer {
    const fileSize = firmware.byteLength
    const suffixBuffer = firmware.slice(fileSize - DFU_SUFFIX_LENGTH, fileSize)
    const suffix = new Uint8Array(suffixBuffer)
    const targetFilter = this._parseDFUSuffix(suffix)

    if (!targetFilter) {
      throw new Error("DFU file suffix invalid.")
    }

    if (
      !DEVICE_FILTERS.some(
        (filter) =>
          filter.vendorId == targetFilter.vendorId &&
          filter.productId == targetFilter.productId
      )
    ) {
      throw new Error("DFU file suffix usb ids mismatch.")
    }

    return firmware.slice(0, fileSize - DFU_SUFFIX_LENGTH)
  }

  // In the case of an ignition bootloader, the firmware is built for stm32 and gd32, as such we need to extract the correct firmware for the claimed device.
  // The file layout is: [STM32 image][STM32 DFU suffix][GD32 image][GD32 DFU suffix].
  // We locate the STM32 suffix by scanning for a DFU signature with an STM32 Ignition PID, then return the image matching the claimed device.
  _extractIgnitionFirmare(firmware: ArrayBuffer): ArrayBuffer {
    const fileSize = firmware.byteLength
    // One view over the whole file; subarray() below creates views, not copies.
    const bytes = new Uint8Array(firmware)
    let stm32SuffixPos = -1
    for (let i = 0; i <= fileSize - DFU_SUFFIX_LENGTH; i++) {
      const targetFilter = this._parseDFUSuffix(
        bytes.subarray(i, i + DFU_SUFFIX_LENGTH)
      )
      if (
        targetFilter &&
        IGNITION_STM32_PIDS.some((p) => p == targetFilter.productId)
      ) {
        console.info("STM32 firmware suffix found at offset " + i + ".")
        stm32SuffixPos = i
        break
      }
    }

    // If we didn't find a suffix, we throw an error
    if (stm32SuffixPos == -1) {
      throw new Error("DFU file suffix invalid.")
    }

    if (this._isIgnitionSTM32()) {
      console.info("Extracting STM32 firmware")
      // STM32 firmware starts at 0 and ends where its suffix begins.
      return firmware.slice(0, stm32SuffixPos)
    }

    if (this._isIgnitionGD32()) {
      console.info("Extracting GD32 firmware")
      // GD32 firmware (plus its own suffix) runs from the end of the STM32
      // suffix to the end of the file.
      const gd32Firmware = firmware.slice(
        stm32SuffixPos + DFU_SUFFIX_LENGTH,
        fileSize
      )
      const suffixBuffer = gd32Firmware.slice(
        gd32Firmware.byteLength - DFU_SUFFIX_LENGTH,
        gd32Firmware.byteLength
      )
      const suffix = new Uint8Array(suffixBuffer)
      const targetFilter = this._parseDFUSuffix(suffix)

      if (!targetFilter) {
        throw new Error("DFU file suffix invalid.")
      }

      if (!IGNITION_GD32_PIDS.some((p) => p == targetFilter.productId)) {
        throw new Error("No valid DFU suffix found.")
      }

      return gd32Firmware.slice(0, gd32Firmware.byteLength - DFU_SUFFIX_LENGTH)
    }

    throw new Error("DFU file suffix invalid.")
  }

  //Public members
  get opened(): boolean {
    return this._devHandle ? this._devHandle.opened : false
  }

  async close() {
    if (this._devHandle) await this._devHandle.close()
  }

  /**
   * Prompts the user to pick a device matching DEVICE_FILTERS, opens it and
   * claims the DFU interface.
   */
  async claim() {
    this._devHandle = await navigator.usb.requestDevice({
      filters: DEVICE_FILTERS
    })
    await this._devHandle.open()
    await this._devHandle.selectConfiguration(DFU_CONF)
    await this._devHandle.claimInterface(DFU_IFACE)
    console.info("DFU device claimed.")
    if (this._isIgnition()) {
      console.info(
        "Ignition bootloader detected, version: " + this.getBootloaderVersion()
      )
    } else {
      console.info("STMicro DFU Bootloader detected.")
    }
  }

  /**
   * Erases the device's flash window, writes `firmware` (stripped of its DFU
   * suffix) in BLOCK_SIZE chunks, then asks the device to reboot.
   * @throws Error on invalid suffix, oversized firmware, short writes, device
   *   error state, or failed control transfers.
   */
  async flash(firmware: ArrayBuffer) {
    const isIgnition = this._isIgnition()
    const firmwareData =
      isIgnition && this._doubleCompilation
        ? this._extractIgnitionFirmare(firmware)
        : this._extractSTM32DFUSuffix(firmware)
    const totalBytes = firmwareData.byteLength

    // Ignition boards have a different start address
    let startAddress = isIgnition
      ? DFU_START_ADDRESS_IGNITION
      : DFU_START_ADDRESS
    const endAddress = isIgnition ? DFU_END_ADDRESS_IGNITION : DFU_END_ADDRESS

    // Only [startAddress, endAddress) is erased; bytes beyond it would be
    // written to unerased flash, so refuse before touching the device.
    const capacity = endAddress - startAddress
    if (totalBytes > capacity) {
      throw new Error(
        "Firmware too large: " +
          totalBytes +
          " bytes, flash region holds " +
          capacity +
          " bytes."
      )
    }

    const status = await this._getStatus()

    this._totalBytesCallback(totalBytes + DFU_RESET_BYTES)
    // Clear device status
    if (status.state == DFU_ERROR) {
      console.info("Device in error state, clearing status.")
      await this._clearStatus()
    }

    // Send erase command
    console.info("Erasing flash.")
    await this._eraseFlash(startAddress, endAddress)

    console.info("Setting start address to: 0x0" + startAddress.toString(16))

    // Send firmare bytes
    console.info("Flashing bytes.")
    const flashedblocks: { address: string; status: string }[] = []
    let bytesSent = 0
    while (bytesSent < totalBytes) {
      const chunkSize = Math.min(totalBytes - bytesSent, BLOCK_SIZE)

      await this._setAddress(startAddress)

      // Block number 2 writes at the address pointer set just above.
      const bytesWritten = await this._download(
        firmwareData.slice(bytesSent, bytesSent + chunkSize),
        2
      )
      // A short write would desynchronise the data offset from the flash address.
      if (bytesWritten !== chunkSize) {
        throw new Error(
          "Short write at 0x0" +
            startAddress.toString(16) +
            ": " +
            bytesWritten +
            " of " +
            chunkSize +
            " bytes."
        )
      }
      await this._waitIdle()

      // Record the address this chunk was written to (before advancing).
      flashedblocks.push({
        address: "0x0" + startAddress.toString(16),
        status: "✅"
      })
      bytesSent += chunkSize
      startAddress += chunkSize
      this._progressCallBack(bytesSent + DFU_RESET_BYTES)
    }
    console.table(flashedblocks)
    console.info("Flashing complete, rebooting keyboard.")

    // Send the reboot packet. Awaited so callers cannot close the device
    // before it is sent; the device may leave the bus while answering, so a
    // transfer failure here is logged rather than failing a completed flash.
    try {
      await this._dfuReboot()
    } catch (e: unknown) {
      console.info("Device disconnected during reboot:", e)
    }
  }

  /** Returns "major.minor.subminor" of the claimed device, or "" if none. */
  getBootloaderVersion(): string {
    if (!this._devHandle) return ""
    const dev = this._devHandle
    return `${dev.deviceVersionMajor}.${dev.deviceVersionMinor}.${dev.deviceVersionSubminor}`
  }

  /** True when the claimed device's product id is an STM32 Ignition PID. */
  _isIgnitionSTM32(): boolean {
    return (
      this._devHandle != null &&
      IGNITION_STM32_PIDS.includes(this._devHandle.productId)
    )
  }

  /** True when the claimed device's product id is a GD32 Ignition PID. */
  _isIgnitionGD32(): boolean {
    return (
      this._devHandle != null &&
      IGNITION_GD32_PIDS.includes(this._devHandle.productId)
    )
  }

  /** True when the claimed device runs any Ignition bootloader. */
  _isIgnition(): boolean {
    return this._isIgnitionSTM32() || this._isIgnitionGD32()
  }
}

export default DFUFlasher
