import logging
import random
from dataclasses import dataclass
from typing import Callable

import cocotb
import cocotb.handle
from cocotb.clock import Clock
from cocotb.queue import Queue
from cocotb.triggers import Trigger, First, Event, Timer, RisingEdge, FallingEdge
@dataclass
class SpiInterface:
    """The four SPI bus signal handles consumed by the SpiSlave model.

    Construct positionally or by keyword, in the order csn, sck, miso, mosi.
    """

    csn: cocotb.handle.LogicObject  # chip select; SpiSlave starts a frame on its falling edge
    sck: cocotb.handle.LogicObject  # serial clock whose edges SpiSlave samples/shifts on
    miso: cocotb.handle.LogicObject  # forced by SpiSlave with the bits it sends
    mosi: cocotb.handle.LogicObject  # read by SpiSlave on every sampling edge
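# NOTE: unfinished design sketch for generic signal monitors; everything below is commented out and unused.
# class SignalMonitor: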
# def expect_change_to_after() -> SignalMonitor:
# def expect_stable_for() -> SignalMonitor:
# def expect_value() -> SignalMonitor:
# def step() -> SignalMonitor:
# async def coroutine():
# class ClockedSignalMonitor:
# def __init__(clock, signal, sampling = RisingEdge):
# self.signal = signal
# self.sampling = sampling
# self.events = [ [] ]
# def expect_change_to_after() -> SignalMonitor:
# def expect_stable_for() -> SignalMonitor:
# def expect_value() -> SignalMonitor:
# def set_value() -> SignalMonitor:
# def set_callback() -> SignalMonitor:
# def step() -> SignalMonitor:
# self.events.append([])
# async def coroutine():
# while True:
# await self.sampling
# if self.events.len() == 0:
# continue
# if self.events[0].len() == 0:
# continue
# events = self.events.pop(0)
# for event in events:
@dataclass
class SpiConfig:
    """Mode and timing parameters for the SpiSlave model.

    ``sampling`` and ``shifting`` are edge-trigger *factories* such as the
    ``RisingEdge`` / ``FallingEdge`` classes, not trigger instances: SpiSlave
    calls ``config.sampling(sck)`` / ``config.shifting(sck)`` to build fresh
    triggers for every bit.
    """

    # Word length in bits used by the tests (they pass 8).
    # NOTE(review): SpiSlave takes the length from each send_data() call and
    # does not read this field.
    transaction_len: int
    # SCK edge on which MOSI is sampled.
    sampling: Callable[..., Trigger]
    # SCK edge on which the next MISO bit is driven.
    shifting: Callable[..., Trigger]
    # Expected SCK period; SpiSlave uses it only for timeouts and the post-word delay.
    sck_period: int
    # Time unit for sck_period, passed straight to Timer (e.g. "ns").
    sck_period_unit: str
class SpiSlave:
    """Behavioural SPI slave: drives MISO and samples MOSI.

    Usage: queue words with send_data(), arm the model with
    expect_transaction_in(), and run coroutine() in the background. Each
    word sampled from MOSI is pushed to ``self.received``
    (see received_data()).
    """

    def __init__(self, spi: SpiInterface, config: SpiConfig):
        self._transaction = Event("spi_slave_transaction")
        self._idle = Event("spi_idle")
        self._wakeup = Event("spi_slave_wakeup")
        self._transaction.clear()
        self._idle.clear()
        self._wakeup.clear()
        self._log = logging.getLogger("cocotb.SpiSlave")
        self.config = config
        # (data, bit_length) tuples waiting to be shifted out on MISO.
        self.transactions = Queue()
        # Words sampled from MOSI, one per served transaction.
        self.received = Queue()
        # (timeout, unit) for the next expected CSN fall; written by
        # expect_transaction_in() before the wakeup event is set.
        self.data = None
        self.sck = spi.sck
        self.csn = spi.csn
        self.tx = spi.miso
        self.rx = spi.mosi

    async def send_data(self, data: int, len: int):
        """Queue ``data`` to be shifted out MSB-first as a ``len``-bit word."""
        item = (data, len)
        await self.transactions.put(item)

    async def received_data(self) -> int:
        """Wait for and return the next word sampled from MOSI."""
        return await self.received.get()

    async def expect_transaction_in(self, max: int, unit: str = "ns"):
        """Arm the model: CSN is expected to fall within ``max`` ``unit`` from now."""
        self.data = (max, unit)
        self._wakeup.set()

    async def wait_one(self):
        """Wait until the model next sets its transaction event.

        The event is set when a queued word starts and when a frame ends
        cleanly.
        """
        self._transaction.clear()
        await self._transaction.wait()

    async def wait_all(self):
        """Wait until the model is idle, i.e. waiting for a new expectation."""
        await self._idle.wait()

    async def coroutine(self):
        """Main loop of the model; start it once in the background."""
        while True:
            self._idle.set()
            await self._wakeup.wait()
            self._log.info("Got new transaction expectation.")
            self._wakeup.clear()
            csn_timeout, csn_unit = self.data
            self._idle.clear()

            csn_falling = FallingEdge(self.csn)
            fired = await First(csn_falling, Timer(csn_timeout, csn_unit))
            if fired != csn_falling:
                self._log.error(f"CSN did not fall in time ({csn_timeout} {csn_unit})!")
                continue

            # CSN fell: serve every queued word within this frame.
            first_word = True
            while not self.transactions.empty():
                data, bits_left = await self.transactions.get()
                total_len = bits_left
                self._transaction.set()
                if first_word:
                    # Drive the MSB immediately. Shifting edges seen before the
                    # first sampling edge of the frame are ignored below, so
                    # this bit is not overwritten.
                    self.tx.value = cocotb.handle.Force((data >> (total_len - 1)) & 1)
                    data = data << 1
                received = 0
                got_sampling = False
                while bits_left > 0:
                    sampling = self.config.sampling(self.sck)
                    shifting = self.config.shifting(self.sck)
                    timeout = Timer(self.config.sck_period * 4, self.config.sck_period_unit)
                    res = await First(sampling, shifting, timeout)
                    if res == timeout:
                        # Best effort: report it and keep waiting for SCK.
                        self._log.error("Got no sck edge in time!")
                        continue
                    if res == shifting:
                        if got_sampling or not first_word:
                            self.tx.value = cocotb.handle.Force(
                                (data >> (total_len - 1)) & 1
                            )
                            data = data << 1
                    elif res == sampling:
                        got_sampling = True
                        received = (received << 1) | int(self.rx.value)
                        bits_left -= 1
                await Timer(self.config.sck_period / 4, self.config.sck_period_unit)
                await self.received.put(received)
                self._log.info(f"Received {received}")
                first_word = False

            # Wait for the master to end the frame.
            timeout = Timer(self.config.sck_period * 2, self.config.sck_period_unit)
            csn_rising = RisingEdge(self.csn)
            res = await First(csn_rising, timeout)
            # Stop forcing MISO even when CSN failed to rise; otherwise the
            # line would stay forced until the next frame.
            self.tx.value = cocotb.handle.Release()
            if res == timeout:
                self._log.error("Got no rising edge on csn")
                continue
            self._transaction.set()
async def init(dut, master: int = 1, tx_en: int = 1, rx_en: int = 1):
    """Drive every DUT input to a known value and pulse the reset.

    ``rst_in`` is held at 0 for two falling clock edges and then set to 1
    (the reset release). Two more falling edges pass before returning. The
    clock on ``clk_i`` must already be running.

    Args:
        dut: cocotb handle of the DUT.
        master: value driven on ``master_i`` (presumably 1 = master mode;
            TODO confirm against the RTL).
        tx_en: value driven on ``tx_en_i``.
        rx_en: value driven on ``rx_en_i``.
    """
    dut._log.info("Init started!")
    dut.miso_io.value = 0
    dut.rst_in.value = 0
    dut.clock_polarity_i.value = 0
    dut.clock_phase_i.value = 0
    dut.size_sel_i.value = 0
    dut.div_sel_i.value = 0
    dut.pulse_csn_i.value = 0
    dut.rx_block_on_full_i.value = 0
    dut.rx_ready_i.value = 0
    dut.tx_valid_i.value = 0
    dut.clear_lost_rx_data_i.value = 0
    # Previously these three were hard-coded to 1, silently ignoring the
    # caller's arguments; the defaults keep that behaviour for existing callers.
    dut.rx_en_i.value = rx_en
    dut.tx_en_i.value = tx_en
    dut.master_i.value = master
    dut.en_i.value = 1
    await FallingEdge(dut.clk_i)
    await FallingEdge(dut.clk_i)
    # Release reset
    dut.rst_in.value = 1
    await FallingEdge(dut.clk_i)
    await FallingEdge(dut.clk_i)
    dut._log.info("Init done!")
class DutDriver:
    """Testbench-side driver for the DUT's parallel TX/RX handshake ports.

    Provides one-shot helpers (receive_data, send_data, send_data_wait).
    It also provides a background coroutine() which, once enabled with
    auto_receive(), collects RX words into an internal queue that
    received_data() reads.
    """

    def __init__(self, dut):
        self.dut = dut
        self._log = logging.getLogger("cocotb.DutDriver")
        # Words captured by coroutine() while auto-receive is enabled.
        self._received = Queue()
        # Set by auto_receive() to wake coroutine() when the mode changes.
        self._sync = Event()
        self._auto_receive = False
        self._set_ready = False

    async def receive_data(self):
        """Take one word from the RX port with a one-clock rx_ready_i pulse.

        Problems are logged as errors, not raised:
        - rx_valid_o is not 1 on entry;
        - rx_valid_o is still 1 one falling edge after the pulse.

        Returns:
            The value of rx_data_o read at the falling edge where rx_ready_i
            is driven high.
        """
        if int(self.dut.rx_valid_o.value) != 1:
            self._log.error("RX is not valid when receiving data was requested")
        await FallingEdge(self.dut.clk_i)
        self.dut.rx_ready_i.value = 1
        data = self.dut.rx_data_o.value
        await FallingEdge(self.dut.clk_i)
        if int(self.dut.rx_valid_o.value) != 0:
            self._log.error("RX data stayed valid after receiving!")
        self.dut.rx_ready_i.value = 0
        return data

    async def send_data(self, data):
        """Present ``data`` on tx_data_i with tx_valid_i high for one clock period.

        Assumes tx_ready_o is already 1; it does not wait for it. Logs an error
        (does not raise) in two cases:
        - tx_ready_o is not 1 on entry;
        - tx_ready_o is still 1 one falling edge later.
        """
        if int(self.dut.tx_ready_o.value) != 1:
            self._log.error("TX is not ready when sending data was requested")
        await FallingEdge(self.dut.clk_i)
        self.dut.tx_valid_i.value = 1
        self.dut.tx_data_i.value = data
        await FallingEdge(self.dut.clk_i)
        self.dut.tx_valid_i.value = 0
        if int(self.dut.tx_ready_o.value) != 0:
            self._log.error("TX stayed ready after requesting send of data!")

    async def send_data_wait(self, data):
        """Present ``data`` and hold tx_valid_i high until the DUT has taken it.

        Unlike send_data(), this tolerates tx_ready_o being low on entry.
        There is no timeout: if tx_ready_o never rises, this waits forever.
        """
        await FallingEdge(self.dut.clk_i)
        self.dut.tx_valid_i.value = 1
        self.dut.tx_data_i.value = data
        clk_falling = FallingEdge(self.dut.clk_i)
        tx_ready_falling = FallingEdge(self.dut.tx_ready_o)
        # If tx_ready_o drops before the next falling clock edge, the word was
        # presumably accepted on the rising edge in between.
        res = await First(clk_falling, tx_ready_falling)
        if res == clk_falling:
            # Not accepted yet: wait for ready, then for the rising clock edge
            # on which the DUT registers the word.
            # TODO timeout
            await RisingEdge(self.dut.tx_ready_o)
            await RisingEdge(self.dut.clk_i)
        # now the data were registered
        # (note that it's ready, not confirmation
        # so data are sampled only after it actually is ready
        # that means the next clock cycle from when ready went to 1)
        self.dut.tx_valid_i.value = 0

    async def auto_receive(self, receive: bool = True, set_ready: bool = True):
        """Configure coroutine() and wake it.

        Args:
            receive: enable collecting words into the internal queue.
            set_ready: hold rx_ready_i at 1 while collecting. When False,
                words are left unconfirmed and the lost-data flag is checked.
        """
        self._auto_receive = receive
        self._set_ready = set_ready
        self._sync.set()

    async def received_data(self):
        """Return the oldest auto-received word, or None if none is queued.

        Never blocks waiting for new data.
        """
        if self._received.empty():
            return None
        return await self._received.get()

    async def coroutine(self):
        """Background loop implementing auto_receive(); start it once."""
        while True:
            if not self._auto_receive:
                # Disabled: deassert ready and sleep until auto_receive() is called.
                self.dut.rx_ready_i.value = 0
                await self._sync.wait()
                self._sync.clear()
                if not self._auto_receive:
                    continue
            if self._set_ready:
                self.dut.rx_ready_i.value = 1
            await RisingEdge(self.dut.rx_valid_o)
            if int(self.dut.rx_valid_o.value) == 1:
                await self._received.put(self.dut.rx_data_o.value)
            # As coded: loss is expected when the word is not being confirmed
            # and the testbench is already presenting another TX word.
            # NOTE(review): confirm this matches the DUT's lost-data condition.
            should_lose_data = not self._set_ready and self.dut.tx_valid_i.value == 1
            await RisingEdge(self.dut.clk_i)
            await RisingEdge(self.dut.clk_i)
            if should_lose_data and int(self.dut.err_lost_rx_data_o.value) != 1:
                self._log.error("Didn't get err lost rx data after rx valid")
            if should_lose_data:
                # Pulse clear_lost_rx_data_i and check the error flag drops.
                self.dut.clear_lost_rx_data_i.value = 1
                await RisingEdge(self.dut.clk_i)
                await FallingEdge(self.dut.clk_i)
                self.dut.clear_lost_rx_data_i.value = 0
                if int(self.dut.err_lost_rx_data_o.value) != 0:
                    self._log.error("Could not clear err lost rx data")
@cocotb.test()
async def single_transmit(dut):
    """Exchange one random byte each way, confirming the DUT's RX word by hand."""
    system_clock = Clock(dut.clk_i, 5, "ns", impl="py")
    bus = SpiInterface(csn=dut.csn_io, sck=dut.sck_io, miso=dut.miso_io, mosi=dut.mosi_io)
    settings = SpiConfig(
        transaction_len=8,
        sampling=RisingEdge,
        shifting=FallingEdge,
        sck_period=10,
        sck_period_unit="ns",
    )
    spi_slave = SpiSlave(bus, settings)
    dut_driver = DutDriver(dut)

    await cocotb.start(system_clock.start())
    await init(dut)
    await cocotb.start(spi_slave.coroutine())
    await cocotb.start(dut_driver.coroutine())

    # Named from the slave model's point of view: `from_master` goes out on
    # MOSI, `to_master` is shifted back on MISO.
    from_master = random.randint(0, 255)
    to_master = random.randint(0, 255)

    await spi_slave.send_data(to_master, 8)
    await spi_slave.expect_transaction_in(15, "ns")
    await dut_driver.send_data(from_master)
    await spi_slave.wait_all()

    word_at_dut = await dut_driver.receive_data()
    assert int(word_at_dut) & 0xFF == to_master

    # Let a few clock cycles elapse before checking the slave's view.
    # NOTE(review): the original intent was that rx data should stay valid
    # here, but nothing re-checks rx_valid_o.
    for _ in range(4):
        await FallingEdge(dut.clk_i)
    word_at_slave = await spi_slave.received_data()
    assert word_at_slave & 0xFF == from_master

    await Timer(100, "ns")
async def perform_multiple_transmits(count, dut, slave, driver, bits: int = 8):
    """Exchange ``count`` random ``bits``-wide words each way and check both sides.

    All slave words are queued first and the slave is armed. The master then
    sends its words one by one with ``driver.send_data_wait``. When the slave
    is idle again, the DUT's auto-received words and the slave's sampled
    words are compared in order.

    The DUT must already be configured for ``bits``-wide transfers, and
    ``driver`` must be in auto-receive mode.

    Args:
        count: number of words in each direction.
        dut: cocotb DUT handle (used for logging).
        slave: SpiSlave model attached to the DUT's SPI pins.
        driver: DutDriver attached to the DUT's parallel ports.
        bits: word width; the default 8 matches the original behaviour.
    """
    mask = (1 << bits) - 1
    tx_data = [random.randint(0, mask) for _ in range(count)]
    rx_data = [random.randint(0, mask) for _ in range(count)]
    for tx in tx_data:
        await slave.send_data(tx, bits)
        dut._log.info(f"Sending Data from slave: {tx}")
    dut._log.info("To expect transaction")
    await slave.expect_transaction_in(15, "ns")
    for rx in rx_data:
        dut._log.info(f"Sending Data from master: {rx}")
        await driver.send_data_wait(rx)
    await slave.wait_all()
    # Checks
    for index, tx in enumerate(tx_data):
        dut_received = await driver.received_data()
        # received_data() returns None when nothing was captured; fail with a
        # clear message instead of a TypeError from int(None).
        assert dut_received is not None, f"DUT did not receive word {index} (expected {tx})"
        got = int(dut_received) & mask
        assert got == tx, f"DUT word {index}: got {got}, expected {tx}"
    for index, rx in enumerate(rx_data):
        received = await slave.received_data()
        got = received & mask
        assert got == rx, f"slave word {index}: got {got}, expected {rx}"
@cocotb.test()
async def multiple_transmits(dut):
    """Five random words each way in mode 0, with auto-receive confirming RX data."""
    system_clock = Clock(dut.clk_i, 5, "ns")
    bus = SpiInterface(csn=dut.csn_io, sck=dut.sck_io, miso=dut.miso_io, mosi=dut.mosi_io)
    settings = SpiConfig(
        transaction_len=8,
        sampling=RisingEdge,
        shifting=FallingEdge,
        sck_period=10,
        sck_period_unit="ns",
    )
    spi_slave = SpiSlave(bus, settings)
    dut_driver = DutDriver(dut)

    await cocotb.start(system_clock.start())
    await init(dut)
    await cocotb.start(spi_slave.coroutine())
    await cocotb.start(dut_driver.coroutine())

    await dut_driver.auto_receive()
    await perform_multiple_transmits(5, dut, spi_slave, dut_driver)
    await Timer(100, "ns")
@cocotb.test()
async def lost_rx_data(dut):
    """Collect RX words without confirming them, exercising the lost-data flag.

    With ``set_ready=False`` the driver's background coroutine checks that
    err_lost_rx_data_o is raised and can be cleared.
    """
    system_clock = Clock(dut.clk_i, 5, "ns")
    bus = SpiInterface(csn=dut.csn_io, sck=dut.sck_io, miso=dut.miso_io, mosi=dut.mosi_io)
    settings = SpiConfig(
        transaction_len=8,
        sampling=RisingEdge,
        shifting=FallingEdge,
        sck_period=10,
        sck_period_unit="ns",
    )
    spi_slave = SpiSlave(bus, settings)
    dut_driver = DutDriver(dut)

    await cocotb.start(system_clock.start())
    await init(dut)
    await cocotb.start(spi_slave.coroutine())
    await cocotb.start(dut_driver.coroutine())

    await dut_driver.auto_receive(receive=True, set_ready=False)
    await perform_multiple_transmits(5, dut, spi_slave, dut_driver)
    await Timer(100, "ns")
@cocotb.test()
async def different_clock(dut):
    """Mode-0 exchange with div_sel_i = 1 (noted as divide-by-4; TODO confirm in RTL)."""
    system_clock = Clock(dut.clk_i, 5, "ns")
    bus = SpiInterface(csn=dut.csn_io, sck=dut.sck_io, miso=dut.miso_io, mosi=dut.mosi_io)
    settings = SpiConfig(
        transaction_len=8,
        sampling=RisingEdge,
        shifting=FallingEdge,
        sck_period=20,
        sck_period_unit="ns",
    )
    spi_slave = SpiSlave(bus, settings)
    dut_driver = DutDriver(dut)

    await cocotb.start(system_clock.start())
    await init(dut)
    # Select the slower SCK divider after reset has been released.
    dut.div_sel_i.value = 1
    await cocotb.start(spi_slave.coroutine())
    await cocotb.start(dut_driver.coroutine())

    await dut_driver.auto_receive()
    await perform_multiple_transmits(5, dut, spi_slave, dut_driver)
    await Timer(100, "ns")
@cocotb.test()
async def inverted_clock(dut):
    """Exchange with clock_phase_i = 1 and clock_polarity_i = 1.

    The slave shifts on falling SCK edges and samples on rising ones.
    """
    system_clock = Clock(dut.clk_i, 5, "ns")
    bus = SpiInterface(csn=dut.csn_io, sck=dut.sck_io, miso=dut.miso_io, mosi=dut.mosi_io)
    # NOTE(review): sck_period=20 although div_sel_i stays 0 (multiple_transmits
    # uses 10 for that setting); it only loosens the model's timeouts and delays.
    settings = SpiConfig(
        transaction_len=8,
        sampling=RisingEdge,
        shifting=FallingEdge,
        sck_period=20,
        sck_period_unit="ns",
    )
    spi_slave = SpiSlave(bus, settings)
    dut_driver = DutDriver(dut)

    await cocotb.start(system_clock.start())
    await init(dut)
    dut.clock_phase_i.value = 1
    dut.clock_polarity_i.value = 1
    await cocotb.start(spi_slave.coroutine())
    await cocotb.start(dut_driver.coroutine())

    await dut_driver.auto_receive()
    await perform_multiple_transmits(5, dut, spi_slave, dut_driver)
    await Timer(100, "ns")
@cocotb.test()
async def shifted_inverted_clock(dut):
    """Two exchange rounds in which the slave samples on falling and shifts on rising SCK.

    Round 1 uses phase 0 / polarity 1; round 2 switches the DUT to
    phase 1 / polarity 0. Both settings keep the same edge roles.
    """
    system_clock = Clock(dut.clk_i, 5, "ns")
    bus = SpiInterface(csn=dut.csn_io, sck=dut.sck_io, miso=dut.miso_io, mosi=dut.mosi_io)
    settings = SpiConfig(
        transaction_len=8,
        sampling=FallingEdge,
        shifting=RisingEdge,
        sck_period=20,
        sck_period_unit="ns",
    )
    spi_slave = SpiSlave(bus, settings)
    dut_driver = DutDriver(dut)
    words_per_round = 3

    await cocotb.start(system_clock.start())
    await init(dut)

    # Round 1: phase 0, polarity 1.
    dut.clock_phase_i.value = 0
    dut.clock_polarity_i.value = 1
    await FallingEdge(dut.clk_i)
    await cocotb.start(spi_slave.coroutine())
    await cocotb.start(dut_driver.coroutine())
    await dut_driver.auto_receive()
    await perform_multiple_transmits(words_per_round, dut, spi_slave, dut_driver)

    # Round 2: phase 1, polarity 0.
    dut.clock_phase_i.value = 1
    dut.clock_polarity_i.value = 0
    await FallingEdge(dut.clk_i)
    await perform_multiple_transmits(words_per_round, dut, spi_slave, dut_driver)

    await Timer(100, "ns")
# Further test ideas:
# - All transfer sizes.
# - RX blocking: no new transmission may start until the received data are confirmed.
# - When the data are read a bit later and CSN pulsing is enabled, CSN should
#   still pulse before the data are obtained.
# - All clock phases and polarities.
# - CSN pulsing.
# - rx_en off: MISO should be ignored and rx_valid should always be 0; TX still works.
# - tx_en off: MOSI should be Z and tx_ready should always be 0; RX still works.