"""cocotb tests for an SPI DUT running in master mode against a behavioral SPI slave model."""

import logging
import random
from dataclasses import dataclass
from typing import Callable

import cocotb
import cocotb.handle
from cocotb.clock import Clock
from cocotb.queue import Queue
from cocotb.triggers import Event, FallingEdge, First, RisingEdge, Timer, Trigger


@dataclass
class SpiInterface:
    """Handles of the four SPI wires."""

    csn: cocotb.handle.LogicObject
    sck: cocotb.handle.LogicObject
    miso: cocotb.handle.LogicObject
    mosi: cocotb.handle.LogicObject


# Sketch of a generic signal monitor (not implemented yet):
#
# class SignalMonitor:
#     def expect_change_to_after() -> SignalMonitor:
#     def expect_stable_for() -> SignalMonitor:
#     def expect_value() -> SignalMonitor:
#     def step() -> SignalMonitor:
#     async def coroutine():
#
# class ClockedSignalMonitor:
#     def __init__(clock, signal, sampling = RisingEdge):
#         self.signal = signal
#         self.sampling = sampling
#         self.events = [ [] ]
#     def expect_change_to_after() -> SignalMonitor:
#     def expect_stable_for() -> SignalMonitor:
#     def expect_value() -> SignalMonitor:
#     def set_value() -> SignalMonitor:
#     def set_callback() -> SignalMonitor:
#     def step() -> SignalMonitor:
#         self.events.append([])
#     async def coroutine():
#         while True:
#             await self.sampling
#             if self.events.len() == 0:
#                 continue
#             if self.events[0].len() == 0:
#                 continue
#             events = self.events.pop(0)
#             for event in events:


@dataclass
class SpiConfig:
    """Timing configuration of the SPI slave model."""

    # Word length in bits. NOTE(review): not read by the code in this file;
    # each queued word carries its own length.
    transaction_len: int
    # Trigger factory applied to SCK; MOSI is sampled on this edge.
    sampling: Callable[..., Trigger]
    # Trigger factory applied to SCK; the next MISO bit is driven on this edge.
    shifting: Callable[..., Trigger]
    # Nominal SCK period, used only to derive timeouts.
    sck_period: int
    sck_period_unit: str


class SpiSlave:
    """Behavioral SPI slave answering the DUT's transfers.

    Queue words with :meth:`send_data`, arm the model with
    :meth:`expect_transaction_in` and read what the DUT sent with
    :meth:`received_data`. :meth:`coroutine` must be running as a task.
    """

    def __init__(self, spi: SpiInterface, config: SpiConfig):
        self._transaction = Event("spi_slave_transaction")
        self._idle = Event("spi_idle")
        self._wakeup = Event("spi_slave_wakeup")
        self.log = logging.getLogger("SpiSlave")
        self.config = config
        # (timeout, unit) for the next CSN falling edge; set by expect_transaction_in().
        self.data = None
        self.transactions = Queue()  # (data, length) words to shift out on MISO
        self.received = Queue()  # words sampled from MOSI
        self.sck = spi.sck
        self.csn = spi.csn
        self.tx = spi.miso
        self.rx = spi.mosi

    async def send_data(self, data: int, len: int):
        """Queue the ``len``-bit word ``data`` to be shifted out MSB first.

        ``len`` shadows the builtin but is kept as the public keyword name.
        """
        await self.transactions.put((data, len))

    async def received_data(self) -> int:
        """Return the next word sampled from MOSI, waiting until one is available."""
        return await self.received.get()

    async def expect_transaction_in(self, max: int, unit: str = "ns"):
        """Arm the model: CSN must fall within ``max`` ``unit`` once it wakes up.

        ``max`` shadows the builtin but is kept as the public keyword name.
        """
        self.data = (max, unit)
        self._wakeup.set()

    async def wait_one(self):
        """Wait until the model next signals transaction progress.

        The event is set when a word starts and again after CSN rises following a word.
        """
        self._transaction.clear()
        await self._transaction.wait()

    async def wait_all(self):
        """Wait until the model is idle, i.e. has served every queued word."""
        await self._idle.wait()

    def _drive_msb(self, data: int, length: int) -> int:
        """Force bit ``length - 1`` of ``data`` onto MISO and return ``data << 1``."""
        self.tx.value = cocotb.handle.Force((data >> (length - 1)) & 1)
        return data << 1

    async def coroutine(self):
        """Serve SPI transactions forever (run as a background task)."""
        while True:
            self._idle.set()
            await self._wakeup.wait()
            # TODO: flag CSN falling while no transaction is expected:
            # first = await First(wakeup, csn_falling)
            # if first == csn_falling:
            #     self.log.error("CSN fell when transaction was unexpected.")
            #     continue
            self.log.info("Got new transaction expectation.")
            self._wakeup.clear()
            max_time, unit = self.data
            self._idle.clear()

            csn_falling = FallingEdge(self.csn)
            first = await First(csn_falling, Timer(max_time, unit))
            if first != csn_falling:
                self.log.error(f"CSN did not fall in time ({max_time} {unit})!")
                continue

            # CSN fell: serve every queued word.
            while not self.transactions.empty():
                data, length = await self.transactions.get()
                self._transaction.set()
                # Pre-drive the first bit before any SCK edge arrives.
                data = self._drive_msb(data, length)
                received = 0
                got_sampling = False
                remaining = length
                while remaining > 0:
                    sampling = self.config.sampling(self.sck)
                    shifting = self.config.shifting(self.sck)
                    timeout = Timer(self.config.sck_period * 4, self.config.sck_period_unit)
                    res = await First(sampling, shifting, timeout)
                    if res == timeout:
                        # NOTE(review): the bit count is not advanced, so a
                        # stalled SCK keeps this loop waiting and logging.
                        self.log.error("Got no sck edge in time!")
                        continue
                    if res == shifting:
                        # The first bit was pre-driven; shift only after a sample.
                        if got_sampling:
                            data = self._drive_msb(data, length)
                    else:
                        got_sampling = True
                        received = (received << 1) | int(self.rx.value)
                        remaining -= 1
                await self.received.put(received)
                self.log.info(f"Received {received}")

                # Now wait for CSN to rise.
                timeout = Timer(self.config.sck_period * 2, self.config.sck_period_unit)
                csn_rising = RisingEdge(self.csn)
                res = await First(csn_rising, timeout)
                if res == timeout:
                    # MISO stays forced; the next queued word (if any) follows directly.
                    self.log.error("Got no rising edge on csn")
                    continue
                self.tx.value = cocotb.handle.Release()
                self._transaction.set()


async def init(dut, master: int = 1, tx_en: int = 1):
    """Drive all DUT inputs to known values, apply and release reset.

    Args:
        dut: cocotb handle of the top-level design.
        master: value driven on ``master_i``.
        tx_en: value driven on ``tx_en_i``.
    """
    dut._log.info("Init started!")
    dut.miso_io.value = 0
    dut.rst_in.value = 0
    dut.clock_polarity_i.value = 0
    dut.clock_phase_i.value = 0
    dut.size_sel_i.value = 0
    dut.div_sel_i.value = 0
    dut.pulse_csn_i.value = 0
    dut.rx_block_on_full_i.value = 0
    dut.rx_ready_i.value = 0
    dut.tx_valid_i.value = 0
    dut.clear_lost_rx_data_i.value = 0
    dut.rx_en_i.value = 1
    dut.tx_en_i.value = tx_en
    dut.master_i.value = master
    dut.en_i.value = 1
    await FallingEdge(dut.clk_i)
    await FallingEdge(dut.clk_i)
    # Release reset (rst_in was held at 0 above).
    dut.rst_in.value = 1
    await FallingEdge(dut.clk_i)
    await FallingEdge(dut.clk_i)
    dut._log.info("Init done!")


class DutDriver:
    """Drives and checks the DUT's parallel tx/rx valid-ready interface."""

    def __init__(self, dut):
        self.dut = dut
        self._log = logging.getLogger("DutDriver")
        self._received = Queue()  # words collected by the auto-receive loop
        self._sync = Event()  # wakes coroutine() after auto_receive() changes settings
        self._auto_receive = False
        self._set_ready = False

    async def receive_data(self):
        """Handshake one word out of ``rx_data_o`` and return its value."""
        if int(self.dut.rx_valid_o.value) != 1:
            self._log.error("RX is not valid when receiving data was requested")
        await FallingEdge(self.dut.clk_i)
        self.dut.rx_ready_i.value = 1
        data = self.dut.rx_data_o.value
        await FallingEdge(self.dut.clk_i)
        if int(self.dut.rx_valid_o.value) != 0:
            self._log.error("RX data stayed valid after receiving!")
        self.dut.rx_ready_i.value = 0
        return data

    async def send_data(self, data):
        """Present ``data`` on ``tx_data_i`` for one clock; expects ``tx_ready_o`` high."""
        if int(self.dut.tx_ready_o.value) != 1:
            self._log.error("TX is not ready when sending data was requested")
        await FallingEdge(self.dut.clk_i)
        self.dut.tx_valid_i.value = 1
        self.dut.tx_data_i.value = data
        await FallingEdge(self.dut.clk_i)
        self.dut.tx_valid_i.value = 0
        if int(self.dut.tx_ready_o.value) != 0:
            self._log.error("TX stayed ready after requesting send of data!")

    async def send_data_wait(self, data):
        """Present ``data`` on ``tx_data_i`` and hold it until the DUT has taken it."""
        await FallingEdge(self.dut.clk_i)
        self.dut.tx_valid_i.value = 1
        self.dut.tx_data_i.value = data
        clk_falling = FallingEdge(self.dut.clk_i)
        tx_ready_falling = FallingEdge(self.dut.tx_ready_o)
        res = await First(clk_falling, tx_ready_falling)
        if res == clk_falling:
            # TODO: timeout
            await RisingEdge(self.dut.tx_ready_o)
            await RisingEdge(self.dut.clk_i)
            # Now the data were registered. tx_ready_o is a ready flag, not a
            # confirmation, so the data are sampled only once it is actually
            # ready, i.e. on the clock edge after it went to 1.
        self.dut.tx_valid_i.value = 0

    async def auto_receive(self, receive: bool = True, set_ready: bool = True):
        """Configure the background receive loop.

        Args:
            receive: collect every word the DUT presents on ``rx_data_o``.
            set_ready: assert ``rx_ready_i``; when False the loop instead checks
                that the DUT raises and clears ``err_lost_rx_data_o``.
        """
        self._auto_receive = receive
        self._set_ready = set_ready
        self._sync.set()

    async def received_data(self):
        """Return the next auto-received word, or ``None`` if none is queued."""
        if self._received.empty():
            return None
        return await self._received.get()

    async def coroutine(self):
        """Background receive loop (run as a task); see :meth:`auto_receive`."""
        while True:
            if not self._auto_receive:
                self.dut.rx_ready_i.value = 0
                await self._sync.wait()
                self._sync.clear()
            if not self._auto_receive:
                continue
            if self._set_ready:
                self.dut.rx_ready_i.value = 1
            await RisingEdge(self.dut.rx_valid_o)
            if int(self.dut.rx_valid_o.value) == 1:
                await self._received.put(self.dut.rx_data_o.value)
            # Without rx_ready while another word is being sent, the DUT is
            # expected to report the received word as lost.
            should_lose_data = not self._set_ready and self.dut.tx_valid_i.value == 1
            await RisingEdge(self.dut.clk_i)
            await RisingEdge(self.dut.clk_i)
            if should_lose_data:
                if int(self.dut.err_lost_rx_data_o.value) != 1:
                    self._log.error("Didn't get err lost rx data after rx valid")
                self.dut.clear_lost_rx_data_i.value = 1
                await RisingEdge(self.dut.clk_i)
                await FallingEdge(self.dut.clk_i)
                self.dut.clear_lost_rx_data_i.value = 0
                if int(self.dut.err_lost_rx_data_o.value) != 0:
                    self._log.error("Could not clear err lost rx data")


async def _start_environment(dut, **clock_kwargs):
    """Start the clock, initialise the DUT and launch the slave model and driver.

    ``clock_kwargs`` are forwarded to :class:`cocotb.clock.Clock`.
    Returns ``(slave, driver)``.
    """
    clk = Clock(dut.clk_i, 5, "ns", **clock_kwargs)
    interface = SpiInterface(dut.csn_io, dut.sck_io, dut.miso_io, dut.mosi_io)
    config = SpiConfig(8, RisingEdge, FallingEdge, 10, "ns")
    slave = SpiSlave(interface, config)
    driver = DutDriver(dut)
    await cocotb.start(clk.start())
    await init(dut)
    await cocotb.start(slave.coroutine())
    await cocotb.start(driver.coroutine())
    return slave, driver


@cocotb.test()
async def single_transmit(dut):
    """One full-duplex word, received manually through the DUT's rx handshake."""
    slave, driver = await _start_environment(dut, impl="py")

    # From the slave's point of view.
    rx = random.randint(0, 255)
    tx = random.randint(0, 255)

    await slave.send_data(tx, 8)
    await slave.expect_transaction_in(15, "ns")
    await driver.send_data(rx)
    await slave.wait_all()

    dut_received = await driver.receive_data()
    assert int(dut_received) & 0xFF == tx

    # Wait a few clocks, rx data should still stay valid!
    for _ in range(4):
        await FallingEdge(dut.clk_i)

    received = await slave.received_data()
    assert received & 0xFF == rx

    await Timer(100, "ns")


@cocotb.test()
async def multiple_transmits(dut):
    """Several back-to-back words with automatic reception on the DUT side."""
    slave, driver = await _start_environment(dut)
    await driver.auto_receive()

    count = 5
    tx_data = [random.randint(0, 255) for _ in range(count)]
    rx_data = [random.randint(0, 255) for _ in range(count)]

    for tx in tx_data:
        await slave.send_data(tx, 8)
        dut._log.info(f"Sending Data from slave: {tx}")

    dut._log.info("To expect transaction")
    await slave.expect_transaction_in(15, "ns")

    for rx in rx_data:
        dut._log.info(f"Sending Data from master: {rx}")
        await driver.send_data_wait(rx)

    await slave.wait_all()

    # Checks
    for tx in tx_data:
        dut_received = await driver.received_data()
        assert dut_received is not None, "DUT received fewer words than the slave sent"
        assert int(dut_received) & 0xFF == tx

    for rx in rx_data:
        received = await slave.received_data()
        assert received & 0xFF == rx

    await Timer(100, "ns")


@cocotb.test()
async def lost_rx_data(dut):
    """Back-to-back words without rx_ready: the DUT must flag and clear lost data."""
    slave, driver = await _start_environment(dut)

    # Do not confirm reception of data. That will auto check
    # if lost is asserted and cleared appropriately.
    await driver.auto_receive(True, False)

    count = 5
    tx_data = [random.randint(0, 255) for _ in range(count)]
    rx_data = [random.randint(0, 255) for _ in range(count)]

    for tx in tx_data:
        await slave.send_data(tx, 8)

    await slave.expect_transaction_in(15, "ns")

    for rx in rx_data:
        await driver.send_data_wait(rx)

    await slave.wait_all()

    # Checks
    # Even though the DUT thinks data were lost, they were actually
    # read and should still be valid data.
    for tx in tx_data:
        dut_received = await driver.received_data()
        assert dut_received is not None, "DUT received fewer words than the slave sent"
        assert int(dut_received) & 0xFF == tx

    for rx in rx_data:
        received = await slave.received_data()
        assert received & 0xFF == rx

    await Timer(100, "ns")


# TODO: further tests
# - All clock phases and polarities.
# - All sizes and divisors.
# - RX blocking: can't start another transmission until the data are confirmed.
# - When the data are read a bit later and CSN pulsing is enabled, CSN should
#   still pulse before the data are obtained.
# - CSN pulse.
# - rx_en off: MISO should be ignored and rx_valid is always 0; TX works fine.
# - tx_en off: MOSI should be Z and tx_ready is always 0; RX works fine.