import logging
from dataclasses import dataclass
from typing import Callable

import cocotb
import cocotb.handle
from cocotb.clock import Clock
from cocotb.queue import Queue
from cocotb.triggers import Trigger, First, Event, Timer, RisingEdge, FallingEdge
@dataclass
class SpiInterface:
    """Bundle of the four SPI signal handles a ``SpiSlave`` attaches to.

    Fields are named after the bus lines, not the slave's direction:
    ``miso`` is driven by the slave model and ``mosi`` is sampled by it.
    """

    # Chip select; the slave model treats a falling edge as transaction start.
    csn: cocotb.handle.LogicObject
    # Serial clock.
    sck: cocotb.handle.LogicObject
    # Master-in / slave-out data line (slave output).
    miso: cocotb.handle.LogicObject
    # Master-out / slave-in data line (slave input).
    mosi: cocotb.handle.LogicObject
# class SignalMonitor:
# def expect_change_to_after() -> SignalMonitor:
# def expect_stable_for() -> SignalMonitor:
# def expect_value() -> SignalMonitor:
# def step() -> SignalMonitor:
# async def coroutine():
# class ClockedSignalMonitor:
# def __init__(clock, signal, sampling = RisingEdge):
# self.signal = signal
# self.sampling = sampling
# self.events = [ [] ]
# def expect_change_to_after() -> SignalMonitor:
# def expect_stable_for() -> SignalMonitor:
# def expect_value() -> SignalMonitor:
# def set_value() -> SignalMonitor:
# def set_callback() -> SignalMonitor:
# def step() -> SignalMonitor:
# self.events.append([])
# async def coroutine():
# while True:
# await self.sampling
# if self.events.len() == 0:
# continue
# if self.events[0].len() == 0:
# continue
# events = self.events.pop(0)
# for event in events:
@dataclass
class SpiConfig:
    """Protocol and timing settings for ``SpiSlave``.

    ``sampling`` and ``shifting`` are trigger *factories* (for example the
    ``RisingEdge`` / ``FallingEdge`` classes), not trigger instances:
    ``SpiSlave`` calls each of them with the sck handle to build a fresh
    trigger for every bit.
    """

    # Word length in bits. NOTE(review): SpiSlave does not read this field;
    # it uses the per-word length passed to send_data() instead.
    transaction_len: int
    # Called as sampling(sck); the edge on which MOSI is sampled.
    sampling: Callable[..., Trigger]
    # Called as shifting(sck); the edge on which the next MISO bit is driven.
    shifting: Callable[..., Trigger]
    # Nominal sck period; SpiSlave derives its edge/CSN timeouts from it.
    sck_period: int
    # Time unit for sck_period, passed to Timer (e.g. "ns").
    sck_period_unit: str
class SpiSlave:
    """Behavioral SPI slave model for cocotb testbenches.

    Usage: start ``coroutine()`` once as a cocotb task, queue words to shift
    out on MISO with ``send_data()``, then arm a transaction window with
    ``expect_transaction_in()``. For each queued word the slave drives bits
    MSB-first on the ``shifting`` edge, samples MOSI on the ``sampling``
    edge, waits for CSN to rise and then publishes the sampled word through
    ``received_data()``. Protocol violations are reported with
    ``self.log.error`` only; they do not raise.
    """

    def __init__(self, spi: SpiInterface, config: SpiConfig):
        # Set each time a word completes successfully; see wait_one().
        self._transaction = Event("spi_slave_transaction")
        # Set while the main loop is parked with no pending expectation.
        self._idle = Event("spi_idle")
        # Set by expect_transaction_in() to wake the main loop.
        self._wakeup = Event("spi_slave_wakeup")
        self.log = logging.getLogger("SpiSlave")
        self.config = config
        # Pending (data, bit_length) tuples to shift out on MISO.
        self.transactions = Queue()
        # Words sampled from MOSI, one per completed transaction.
        self.received = Queue()
        # (max, unit) window stored by expect_transaction_in(). Single slot:
        # arming twice before the main loop consumes it keeps only the last.
        self.data = None
        self.sck = spi.sck
        self.csn = spi.csn
        self.tx = spi.miso
        self.rx = spi.mosi

    async def send_data(self, data: int, len: int):
        """Queue ``data`` to be shifted out MSB-first as a ``len``-bit word.

        Raises:
            ValueError: if ``len`` is less than 1 (there is no MSB to drive;
                previously this crashed the background coroutine instead).
        """
        # ``len`` shadows the builtin but is kept for keyword callers.
        if len < 1:
            raise ValueError(f"SPI word length must be >= 1, got {len}")
        await self.transactions.put((data, len))

    async def received_data(self) -> int:
        """Return the next word sampled from MOSI, waiting until one exists."""
        return await self.received.get()

    async def expect_transaction_in(self, max: int, unit: str = "ns"):
        """Arm the slave to expect CSN to fall within ``max`` ``unit``.

        ``_idle`` is cleared here, not only in the main loop, so that a
        ``wait_all()`` issued right after this call cannot return before the
        main loop has picked up the expectation.
        """
        self.data = (max, unit)
        self._idle.clear()
        self._wakeup.set()

    async def wait_one(self):
        """Wait until the next word completes successfully.

        Only completions after this call are observed; a word that already
        finished is not counted. If the transaction fails (timeouts), this
        keeps waiting for a later successful one.
        """
        self._transaction.clear()
        await self._transaction.wait()

    async def wait_all(self):
        """Wait until the main loop is parked with no pending expectation."""
        await self._idle.wait()

    async def coroutine(self):
        """Run the slave forever; start it once as a cocotb task."""
        while True:
            # Report idle only when nothing is armed; otherwise wait_all()
            # would be woken although another expectation is pending.
            if not self._wakeup.is_set():
                self._idle.set()
            await self._wakeup.wait()
            self.log.info("Got new transaction expectation.")
            self._wakeup.clear()
            window, unit = self.data
            self._idle.clear()

            csn_falling = FallingEdge(self.csn)
            first = await First(csn_falling, Timer(window, unit))
            if first != csn_falling:
                self.log.error(f"CSN did not fall in time ({window} {unit})!")
                continue

            # CSN fell: serve every queued word.
            while not self.transactions.empty():
                data, length = await self.transactions.get()
                received = await self._shift_word(data, length)
                if received is None:
                    # sck stopped: stop driving MISO and abandon this
                    # expectation; any remaining words stay queued.
                    self.tx.value = cocotb.handle.Release()
                    break
                csn_rose = await self._wait_csn_rising()
                # Release MISO whether or not CSN rose, so the model never
                # keeps forcing the line after a failed transaction.
                self.tx.value = cocotb.handle.Release()
                if not csn_rose:
                    continue
                await self.received.put(received)
                self.log.info(f"Received {received}")
                self._transaction.set()

    async def _shift_word(self, data: int, length: int):
        """Exchange one ``length``-bit word; return the MOSI word, or None on sck timeout."""
        msb = length - 1
        # Present the MSB before the first clock edge.
        self.tx.value = cocotb.handle.Force((data >> msb) & 1)
        data <<= 1
        received = 0
        got_sampling = False
        remaining = length
        while remaining > 0:
            sampling = self.config.sampling(self.sck)
            shifting = self.config.shifting(self.sck)
            timeout = Timer(self.config.sck_period * 4, self.config.sck_period_unit)
            res = await First(sampling, shifting, timeout)
            if res == timeout:
                self.log.error("Got no sck edge in time!")
                return None
            if res == shifting:
                # A shifting edge before the first sampling edge is ignored:
                # the MSB was already driven above.
                if got_sampling:
                    self.tx.value = cocotb.handle.Force((data >> msb) & 1)
                    data <<= 1
            else:
                got_sampling = True
                # NOTE(review): int() of an X/Z MOSI value raises; the
                # "Z on mosi" case is not handled here.
                received = (received << 1) | int(self.rx.value)
                remaining -= 1
        return received

    async def _wait_csn_rising(self) -> bool:
        """Wait up to two sck periods for CSN to rise; return whether it did."""
        timeout = Timer(self.config.sck_period * 2, self.config.sck_period_unit)
        csn_rising = RisingEdge(self.csn)
        res = await First(csn_rising, timeout)
        if res == timeout:
            self.log.error("Got no rising edge on csn")
            return False
        return True
async def init(dut, master: int = 1, tx_en: int = 1):
    """Drive the DUT's control inputs to known values and pulse reset.

    ``rx_en_i`` and ``en_i`` are driven to 1, ``master_i`` and ``tx_en_i``
    to the given arguments, and every other listed control input (and
    ``miso_io``) to 0. ``rst_in`` is held at 0 for two falling edges of
    ``clk_i``, then set to 1 (reset released). The function returns two
    falling edges after the release.

    Args:
        dut: cocotb handle of the DUT top level.
        master: value driven on ``master_i`` (previously ignored; always 1).
        tx_en: value driven on ``tx_en_i`` (previously ignored; always 1).
    """
    dut._log.info("Init started!")
    dut.miso_io.value = 0
    dut.rst_in.value = 0
    dut.clock_polarity_i.value = 0
    dut.clock_phase_i.value = 0
    dut.size_sel_i.value = 0
    dut.div_sel_i.value = 0
    dut.pulse_csn_i.value = 0
    dut.rx_block_on_full_i.value = 0
    dut.rx_ready_i.value = 0
    dut.tx_valid_i.value = 0
    dut.clear_lost_rx_data_i.value = 0
    dut.rx_en_i.value = 1
    dut.tx_en_i.value = tx_en
    dut.master_i.value = master
    dut.en_i.value = 1
    # Hold reset for two clock cycles.
    await FallingEdge(dut.clk_i)
    await FallingEdge(dut.clk_i)
    # Release reset
    dut.rst_in.value = 1
    await FallingEdge(dut.clk_i)
    await FallingEdge(dut.clk_i)
    dut._log.info("Init done!")
@cocotb.test()
async def simple_test(dut):
    """Single 8-bit exchange: the DUT transmits 100, the slave model replies 123."""
    clock = Clock(dut.clk_i, 5, "ns")
    spi_bus = SpiInterface(
        csn=dut.csn_io,
        sck=dut.sck_io,
        miso=dut.miso_io,
        mosi=dut.mosi_io,
    )
    spi_cfg = SpiConfig(
        transaction_len=8,
        sampling=RisingEdge,
        shifting=FallingEdge,
        sck_period=10,
        sck_period_unit="ns",
    )
    slave = SpiSlave(spi_bus, spi_cfg)

    await cocotb.start(clock.start())
    await init(dut)
    await cocotb.start(slave.coroutine())

    # Queue the slave's reply, then arm it to see CSN fall within 15 ns.
    await slave.send_data(123, 8)
    await slave.expect_transaction_in(15, "ns")

    # Hold tx_valid_i high for one clock period with the word to transmit.
    await FallingEdge(dut.clk_i)
    dut.tx_valid_i.value = 1
    dut.tx_data_i.value = 100
    await FallingEdge(dut.clk_i)
    dut.tx_valid_i.value = 0

    await slave.wait_one()
    # DUT side: it must have captured the word the slave shifted out.
    assert int(dut.rx_valid_o.value) == 1
    assert int(dut.rx_data_o.value) & 0xFF == 123
    # Slave side: it must have sampled the word the DUT shifted out.
    slave_rx = await slave.received_data()
    assert slave_rx & 0xFF == 100

    await Timer(100, "ns")
# Test plan (TODO):
# Simple single receive/transmit
# Check that csn goes low
# Check for 8 sck pulses
# All clock phases and polarities
# All sizes and divisors
# Rx blocking - no new transfer may start until the received data are confirmed.
# When the received data are read some time later and csn pulsing is enabled, csn should still pulse before the data are read.
# Multiple receive/transmit transfers without a csn pulse
# All clock phases and polarities
# Multiple receive/transmit transfers with a csn pulse
# Transmit only
# Rx is not valid
# Receive only
# Z on mosi