import logging
from dataclasses import dataclass

import cocotb
import cocotb.handle
from cocotb.clock import Clock
from cocotb.queue import Queue
from cocotb.triggers import Trigger, First, Event, Timer, RisingEdge, FallingEdge


@dataclass
class SpiInterface:
    """Bundle of the four SPI wires seen by the slave model.

    ``SpiSlave`` drives ``miso`` (with ``Force``/``Release``) and samples
    ``mosi``; ``csn`` and ``sck`` are only observed.
    """

    csn: cocotb.handle.LogicObject
    sck: cocotb.handle.LogicObject
    miso: cocotb.handle.LogicObject
    mosi: cocotb.handle.LogicObject


# TODO: generic signal-monitor helpers, not implemented yet. Design sketch:
#
# class SignalMonitor:
#     def expect_change_to_after() -> SignalMonitor:
#     def expect_stable_for() -> SignalMonitor:
#     def expect_value() -> SignalMonitor:
#     def step() -> SignalMonitor:
#     async def coroutine():
#
# class ClockedSignalMonitor:
#     def __init__(clock, signal, sampling = RisingEdge):
#         self.signal = signal
#         self.sampling = sampling
#         self.events = [ [] ]
#     def expect_change_to_after() -> SignalMonitor:
#     def expect_stable_for() -> SignalMonitor:
#     def expect_value() -> SignalMonitor:
#     def set_value() -> SignalMonitor:
#     def set_callback() -> SignalMonitor:
#     def step() -> SignalMonitor:
#         self.events.append([])
#     async def coroutine():
#         while True:
#             await self.sampling
#             if self.events.len() == 0:
#                 continue
#             if self.events[0].len() == 0:
#                 continue
#             events = self.events.pop(0)
#             for event in events:


@dataclass
class SpiConfig:
    """Mode and timing configuration for :class:`SpiSlave`."""

    # Word length in bits. NOTE(review): not read by SpiSlave in this file;
    # the per-word length is passed to ``SpiSlave.send_data`` instead.
    transaction_len: int
    # Trigger class (e.g. ``RisingEdge``) instantiated on SCK; MOSI is
    # sampled when it fires.
    sampling: Trigger
    # Trigger class instantiated on SCK; the next MISO bit is driven when it
    # fires.
    shifting: Trigger
    # Nominal SCK period; SpiSlave's timeouts are multiples of it.
    sck_period: int
    # Time unit for ``sck_period`` (e.g. "ns").
    sck_period_unit: str


class SpiSlave:
    """Behavioural SPI slave used to check an SPI master DUT.

    Typical use::

        cocotb.start_soon(slave.coroutine())
        await slave.send_data(word, nbits)          # queue MISO data
        await slave.expect_transaction_in(15, "ns") # arm: CSN must fall soon
        ...                                         # make the DUT transmit
        await slave.wait_one()                      # word finished
        rx = await slave.received_data()            # word seen on MOSI
    """

    def __init__(self, spi: SpiInterface, config: SpiConfig):
        # Set after every word finishes (successfully or not).
        self._transaction = Event("spi_slave_transaction")
        # Set while the slave loop waits for a new expectation.
        self._idle = Event("spi_idle")
        # Set by expect_transaction_in() to arm the slave loop.
        self._wakeup = Event("spi_slave_wakeup")
        self._transaction.clear()
        self._idle.clear()
        self._wakeup.clear()
        self.log = logging.getLogger("SpiSlave")
        self.config = config
        # (data, bit_length) tuples to shift out on MISO, MSB first.
        self.transactions = Queue()
        # Words sampled from MOSI, one per successful transfer.
        self.received = Queue()
        self.sck = spi.sck
        self.csn = spi.csn
        self.tx = spi.miso
        self.rx = spi.mosi

    async def send_data(self, data: int, len: int):
        """Queue ``data`` to be shifted out on MISO as a ``len``-bit word, MSB first.

        The parameter name ``len`` is kept for keyword-argument compatibility.
        """
        await self.transactions.put((data, len))

    async def received_data(self) -> int:
        """Wait for and return the next word captured from MOSI."""
        return await self.received.get()

    async def expect_transaction_in(self, max: int, unit: str = "ns"):
        """Arm the slave: CSN must fall within ``max`` ``unit`` from now.

        Only one pending expectation is stored; calling this again before the
        slave loop consumes it overwrites the previous one.
        """
        self.data = (max, unit)
        self._wakeup.set()

    async def wait_one(self):
        """Wait until the slave finishes its next word.

        The wait ends whether the word succeeded or failed; failures are
        logged. A word that completed before this call is not observed.
        """
        self._transaction.clear()
        await self._transaction.wait()

    async def wait_all(self):
        """Wait until the slave loop is back to waiting for an expectation.

        NOTE(review): ``_idle`` is only cleared once the loop consumes an
        expectation, so calling this right after ``expect_transaction_in``
        may return immediately.
        """
        await self._idle.wait()

    async def _transfer(self, data: int, length: int) -> "int | None":
        """Shift one ``length``-bit word MSB-first and wait for CSN to rise.

        Returns the word sampled from MOSI, or ``None`` (after logging an
        error) if the length is invalid, SCK stalls, or CSN does not rise in
        time. MISO is left forced; the caller releases it.
        """
        if length <= 0:
            self.log.error(f"Invalid transaction length {length}!")
            return None

        period = self.config.sck_period
        period_unit = self.config.sck_period_unit
        sampling = self.config.sampling(self.sck)
        shifting = self.config.shifting(self.sck)
        msb = length - 1

        # Present the first bit before any SCK edge.
        self.tx.value = cocotb.handle.Force((data >> msb) & 1)
        data <<= 1

        received = 0
        got_sampling = False
        remaining = length
        while remaining > 0:
            timeout = Timer(period * 4, period_unit)
            res = await First(sampling, shifting, timeout)
            if res is timeout:
                # Abort the word instead of retrying forever.
                self.log.error("Got no sck edge in time!")
                return None
            if res is shifting:
                # The first bit is already driven, so shifting edges that
                # come before the first sampling edge are ignored.
                if got_sampling:
                    self.tx.value = cocotb.handle.Force((data >> msb) & 1)
                    data <<= 1
            else:
                got_sampling = True
                received = (received << 1) | int(self.rx.value)
                remaining -= 1

        timeout = Timer(period * 2, period_unit)
        csn_rising = RisingEdge(self.csn)
        res = await First(csn_rising, timeout)
        if res is timeout:
            self.log.error("Got no rising edge on csn")
            return None
        return received

    async def coroutine(self):
        """Run the slave model forever (start it as a background task).

        Each iteration waits for :meth:`expect_transaction_in`, requires CSN
        to fall within the given time, then processes every word queued with
        :meth:`send_data`. After each word MISO is released and the
        completion event is set; only successful words are put into
        :attr:`received`.
        """
        while True:
            self._idle.set()
            await self._wakeup.wait()
            self.log.info("Got new transaction expectation.")
            self._wakeup.clear()
            deadline, unit = self.data
            self._idle.clear()

            csn_falling = FallingEdge(self.csn)
            first = await First(csn_falling, Timer(deadline, unit))
            if first is not csn_falling:
                self.log.error(f"CSN did not fall in time ({deadline} {unit})!")
                continue

            # CSN fell: serve every queued word.
            while not self.transactions.empty():
                data, length = await self.transactions.get()
                received = await self._transfer(data, length)
                # Never leave MISO forced after a word, even a failed one.
                self.tx.value = cocotb.handle.Release()
                if received is not None:
                    await self.received.put(received)
                    self.log.info(f"Received {received}")
                # Signal completion only now, so wait_one() never returns
                # before the word is done (and never hangs on a failure).
                self._transaction.set()


async def init(dut, master: int = 1, tx_en: int = 1):
    """Drive every DUT input to a known state, then release reset.

    Args:
        dut: cocotb handle to the SPI peripheral toplevel.
        master: value driven on ``master_i``.
        tx_en: value driven on ``tx_en_i``.
    """
    dut._log.info("Init started!")
    dut.miso_io.value = 0
    dut.rst_in.value = 0
    dut.clock_polarity_i.value = 0
    dut.clock_phase_i.value = 0
    dut.size_sel_i.value = 0
    dut.div_sel_i.value = 0
    dut.pulse_csn_i.value = 0
    dut.rx_block_on_full_i.value = 0
    dut.rx_ready_i.value = 0
    dut.tx_valid_i.value = 0
    dut.clear_lost_rx_data_i.value = 0
    dut.rx_en_i.value = 1
    dut.tx_en_i.value = tx_en
    dut.master_i.value = master
    dut.en_i.value = 1
    await FallingEdge(dut.clk_i)
    await FallingEdge(dut.clk_i)

    # Release reset (rst_in is driven 0 above, 1 here).
    dut.rst_in.value = 1
    await FallingEdge(dut.clk_i)
    await FallingEdge(dut.clk_i)
    dut._log.info("Init done!")


@cocotb.test()
async def simple_test(dut):
    """One 8-bit full-duplex exchange: DUT sends 100, slave answers 123."""
    clk = Clock(dut.clk_i, 5, "ns")
    interface = SpiInterface(dut.csn_io, dut.sck_io, dut.miso_io, dut.mosi_io)
    config = SpiConfig(8, RisingEdge, FallingEdge, 10, "ns")
    slave = SpiSlave(interface, config)

    await cocotb.start(clk.start())
    await init(dut)
    await cocotb.start(slave.coroutine())

    await slave.send_data(123, 8)
    await slave.expect_transaction_in(15, "ns")

    # Hand one word to the DUT's TX interface for a single cycle.
    await FallingEdge(dut.clk_i)
    dut.tx_valid_i.value = 1
    dut.tx_data_i.value = 100
    await FallingEdge(dut.clk_i)
    dut.tx_valid_i.value = 0

    await slave.wait_one()

    assert int(dut.rx_valid_o.value) == 1
    assert int(dut.rx_data_o.value) & 0xFF == 123

    received = await slave.received_data()
    assert received & 0xFF == 100

    await Timer(100, "ns")


# TODO test plan:
# - Simple one receive, transmit
#   - Check csn goes low
#   - Check 8 sck pulses
#   - All clock phases and polarities
#   - All sizes, divisors
# - Rx blocking - can't go to another transmission until data are confirmed.
#   - When data are read a bit later and csn pulsing is enabled, csn should
#     still pulse before the data are obtained.
# - Multiple times receive, transmit without pulse
#   - All clock phases and polarities
# - Multiple times receive, transmit with pulse
# - Only transmission
#   - Rx is not valid
# - Only reception
#   - Z on mosi