@@ 1,48 1,256 @@
+from dataclasses import dataclass
+from cocotb.queue import Queue
import cocotb
-from cocotb.triggers import Timer
+import cocotb.handle
+from cocotb.clock import Clock
+from cocotb.triggers import Trigger, First, Event, Timer, RisingEdge, FallingEdge
+import logging
-@cocotb.test()
-async def simple_test(dut):
- dut.miso_i.value = 1;
- dut.clk_i.value = 0;
+@dataclass
+class SpiInterface:
+ csn: cocotb.handle.LogicObject
+ sck: cocotb.handle.LogicObject
+ miso: cocotb.handle.LogicObject
+ mosi: cocotb.handle.LogicObject
+
+# class SignalMonitor:
+# def expect_change_to_after() -> SignalMonitor:
+# def expect_stable_for() -> SignalMonitor:
+# def expect_value() -> SignalMonitor:
+# def step() -> SignalMonitor:
+
+# async def coroutine():
+
+# class ClockedSignalMonitor:
+# def __init__(clock, signal, sampling = RisingEdge):
+# self.signal = signal
+# self.sampling = sampling
+# self.events = [ [] ]
+
+# def expect_change_to_after() -> SignalMonitor:
+# def expect_stable_for() -> SignalMonitor:
+# def expect_value() -> SignalMonitor:
+# def set_value() -> SignalMonitor:
+
+# def set_callback() -> SignalMonitor:
+
+# def step() -> SignalMonitor:
+# self.events.append([])
+
+# async def coroutine():
+# while True:
+# await self.sampling
+
+# if self.events.len() == 0:
+# continue
+
+# if self.events[0].len() == 0:
+# continue
+
+# events = self.events.pop(0)
+
+# for event in events:
+
+@dataclass
+class SpiConfig:
+ transaction_len: int
+ sampling: Trigger
+ shifting: Trigger
+ sck_period: int
+ sck_period_unit: str
+
+class SpiSlave:
+ def __init__(self, spi: SpiInterface, config: SpiConfig):
+ self._transaction = Event("spi_slave_transaction")
+ self._idle = Event("spi_idle")
+ self._wakeup = Event("spi_slave_wakeup")
+
+ self._transaction.clear()
+ self._idle.clear()
+ self._wakeup.clear()
+
+ self.log = logging.getLogger("SpiSlave")
+
+ self.config = config
+ self.transactions = Queue()
+ self.received = Queue()
+
+ self.sck = spi.sck
+ self.csn = spi.csn
+ self.tx = spi.miso
+ self.rx = spi.mosi
+
+ async def send_data(self, data: int, len: int):
+ item = (data, len)
+ await self.transactions.put(item)
+
+ async def received_data(self) -> int:
+ return await self.received.get()
+
+ async def expect_transaction_in(self, max: int, unit: str = "ns"):
+ self.data = (max, unit)
+ self._wakeup.set()
+
+ async def wait_one(self):
+ self._transaction.clear()
+ await self._transaction.wait()
+
+ async def wait_all(self):
+ await self._idle.wait()
+
+ async def coroutine(self):
+ while True:
+ self._idle.set()
+ await self._wakeup.wait()
+
+ # first = await First(wakeup, csn_falling)
+ # if first == csn_falling:
+ # self.log.error("CSN fell when transaction was unexpected.")
+ # continue
+
+ self.log.info("Got new transaction expectation.")
+
+ self._wakeup.clear()
+ max, unit = self.data
+ self._idle.clear()
+
+ csn_falling = FallingEdge(self.csn)
+
+ first = await First(csn_falling, Timer(max, unit))
+
+ if first != csn_falling:
+ self.log.error(f"CSN did not fall in time ({max} {unit})!")
+ continue
+
+ # csn fell
+
+ while not self.transactions.empty():
+ # expect clock
+ data, len = await self.transactions.get()
+ total_len = len
+ self._transaction.set()
+ self.tx.value = cocotb.handle.Force((data >> (total_len - 1)) & 1)
+ data = data << 1
+ received = 0
+
+ got_sampling = False
+
+ while len > 0:
+ sampling = self.config.sampling(self.sck)
+ shifting = self.config.shifting(self.sck)
+ timeout = Timer(self.config.sck_period * 4, self.config.sck_period_unit)
+ res = await First(sampling, shifting, timeout)
+
+ if res == timeout:
+ self.log.error("Got no sck edge in time!")
+ continue
+
+ if res == shifting:
+ if got_sampling:
+ self.tx.value = cocotb.handle.Force((data >> (total_len - 1)) & 1)
+ data = data << 1
+ else:
+ got_sampling = True
+ received = received << 1
+ received |= int(self.rx.value)
+ len -= 1
+
+ # now wait for csn rising
+ timeout = Timer(self.config.sck_period * 2, self.config.sck_period_unit)
+ csn_rising = RisingEdge(self.csn)
+ res = await First(csn_rising, timeout)
+
+ if res == timeout:
+ self.log.error("Got no rising edge on csn")
+ continue
+
+ self.tx.value = cocotb.handle.Release()
+ await self.received.put(received)
+ self.log.info(f"Received {received}")
+
+ # good, continue
+ self._transaction.set()
+
+
+async def init(dut, master: int = 1, tx_en: int = 1):
+ dut._log.info("Init started!")
+ dut.miso_io.value = 0;
dut.rst_in.value = 0;
- dut.en_i.value = 1;
dut.clock_polarity_i.value = 0;
dut.clock_phase_i.value = 0;
dut.size_sel_i.value = 0;
dut.div_sel_i.value = 0;
dut.pulse_csn_i.value = 0;
dut.rx_block_on_full_i.value = 0;
- dut.rx_en_i.value = 1;
dut.rx_ready_i.value = 0;
- dut.tx_en_i.value = 1;
dut.tx_valid_i.value = 0;
- dut.clear_lost_rx_data_i.value = 1;
+ dut.clear_lost_rx_data_i.value = 0;
+
+ dut.rx_en_i.value = 1;
+ dut.tx_en_i.value = 1;
+
+ dut.master_i.value = 1;
+ dut.en_i.value = 1;
- await Timer(5, "ns")
+ await FallingEdge(dut.clk_i)
+ await FallingEdge(dut.clk_i)
+
+ # Release reset
dut.rst_in.value = 1;
- await Timer(5, "ns")
- dut.clk_i.value = 1;
- await Timer(5, "ns")
- dut.clk_i.value = 0;
+ await FallingEdge(dut.clk_i)
+ await FallingEdge(dut.clk_i)
- await Timer(5, "ns")
- dut.clk_i.value = 1;
- await Timer(5, "ns")
- dut.clk_i.value = 0;
+ dut._log.info("Init done!")
- dut.tx_valid_i.value = 1;
- dut.tx_data_i.value = 100;
+@cocotb.test()
+async def simple_test(dut):
+ clk = Clock(dut.clk_i, 5, "ns")
+ interface = SpiInterface(dut.csn_io, dut.sck_io, dut.miso_io, dut.mosi_io)
+ config = SpiConfig(8, RisingEdge, FallingEdge, 10, "ns")
+ slave = SpiSlave(interface, config)
- await Timer(5, "ns")
- dut.clk_i.value = 1;
- await Timer(5, "ns")
- dut.clk_i.value = 0;
- dut.tx_valid_i.value = 0;
+ await cocotb.start(clk.start())
+
+ await init(dut)
+
+ await cocotb.start(slave.coroutine())
+
+ await slave.send_data(123, 8)
+ await slave.expect_transaction_in(15, "ns")
+
+ await FallingEdge(dut.clk_i)
+ dut.tx_valid_i.value = 1
+ dut.tx_data_i.value = 100
+
+ await FallingEdge(dut.clk_i)
+ dut.tx_valid_i.value = 0
+
+ await slave.wait_one()
+
+ assert int(dut.rx_valid_o.value) == 1
+ assert int(dut.rx_data_o.value) & 0xFF == 123
+
+ received = await slave.received_data()
+
+ assert received & 0xFF == 100
+
+ await Timer(100, "ns")
+
+# Simple one receive, transmit
+ # Check csn goes low
+ # Check 8 sck pulses
+# All clock phases and polarities
+# All sizes, divisors
+# Rx blocking - Can't go to another transmission until data confirmed.
+ # When data read a bit later, and csn pulsing is enabled, the csn should still pulse, before data are obtained
+
+# Multiple times receive, transmit without pulse
+ # All clock phases and polarities
+
+# Multiple times receive, transmit with pulse
- for cycle in range(1, 100):
- dut.clk_i.value = 1;
- await Timer(5, "ns")
- dut.clk_i.value = 0;
- await Timer(5, "ns")
+# Only transmission
+ # Rx is not valid
+# Only reception
+ # Z on mosi