import logging
from dataclasses import dataclass
from typing import Callable

import cocotb
import cocotb.handle
from cocotb.clock import Clock
from cocotb.queue import Queue
from cocotb.triggers import Trigger, First, Event, Timer, Edge, RisingEdge, FallingEdge
from cocotb.utils import get_sim_time
@dataclass
class SpiInterface:
    """Group the four simulator signal handles of one SPI bus.

    SpiSlave takes an instance of this class and keeps a reference to each
    handle.
    """

    # Chip select; SpiSlave starts a transfer on its falling edge and expects
    # a rising edge once the transfer is over.
    csn: cocotb.handle.LogicObject
    # Serial clock; SpiSlave builds its sampling/shifting triggers on it.
    sck: cocotb.handle.LogicObject
    # Slave output; SpiSlave forces and releases this signal.
    miso: cocotb.handle.LogicObject
    # Slave input; SpiSlave reads this signal on every sampling edge.
    mosi: cocotb.handle.LogicObject
# class SignalMonitor:
# def expect_change_to_after() -> SignalMonitor:
# def expect_stable_for() -> SignalMonitor:
# def expect_value() -> SignalMonitor:
# def step() -> SignalMonitor:
# async def coroutine():
# class ClockedSignalMonitor:
# def __init__(clock, signal, sampling = RisingEdge):
# self.signal = signal
# self.sampling = sampling
# self.events = [ [] ]
# def expect_change_to_after() -> SignalMonitor:
# def expect_stable_for() -> SignalMonitor:
# def expect_value() -> SignalMonitor:
# def set_value() -> SignalMonitor:
# def set_callback() -> SignalMonitor:
# def step() -> SignalMonitor:
# self.events.append([])
# async def coroutine():
# while True:
# await self.sampling
# if self.events.len() == 0:
# continue
# if self.events[0].len() == 0:
# continue
# events = self.events.pop(0)
# for event in events:
@dataclass
class SpiConfig:
    """Timing and protocol settings consumed by SpiSlave."""

    # Not read by any code visible in this file.
    # NOTE(review): presumably the nominal word length in bits - confirm.
    transaction_len: int
    # Trigger factory (e.g. the RisingEdge class): SpiSlave calls it with the
    # sck handle and samples mosi whenever the returned trigger fires.
    sampling: Callable[[cocotb.handle.LogicObject], Trigger]
    # Trigger factory: SpiSlave calls it with the sck handle and drives the
    # next miso bit whenever the returned trigger fires.
    shifting: Callable[[cocotb.handle.LogicObject], Trigger]
    # Nominal sck period. SpiSlave rejects sck pulses shorter than half of it
    # and times out after waiting two periods for an expected edge.
    sck_period: int
    # Time unit of sck_period, passed to Timer and get_sim_time.
    sck_period_unit: str
    # When true, SpiSlave exchanges a single queued word per chip-select cycle.
    csn_pulse: bool = False
    # Level sck must have when csn falls and when csn rises.
    clock_polarity: int = 0
class SpiSlave:
    """Cocotb model of an SPI slave checked against a master under test.

    Words queued with send_data() are shifted out on miso while the bits seen
    on mosi are collected and made available through received_data(). The
    model only reacts to a chip-select cycle after expect_transaction_in()
    announced it, and raises when csn or sck do not behave as expected.

    NOTE(review): this class was reconstructed from a source whose indentation
    was lost. Two nesting choices should be confirmed against the original:
    the wait for a new expectation in coroutine() is unconditional (it is not
    skipped while words are still queued), and the 1 ns delay before a word is
    published runs once per word rather than once per sck edge.
    """

    def __init__(self, spi: SpiInterface, config: SpiConfig):
        """Store the signal handles and create the queues and events.

        Args:
            spi: Handles of the bus this slave is attached to.
            config: Timing and protocol settings.
        """
        # Set when a word is dequeued and again when a csn cycle has ended.
        self._transaction = Event("spi_slave_transaction")
        # Set while the model has no queued word at the top of its main loop.
        self._idle = Event("spi_idle")
        # Set by expect_transaction_in() to release the main loop.
        self._wakeup = Event("spi_slave_wakeup")
        self._log = logging.getLogger("cocotb.SpiSlave")
        self.config = config
        # (max, unit) deadline stored by expect_transaction_in().
        self.data = None
        self.transactions = Queue()
        self.received = Queue()
        self.sck = spi.sck
        self.csn = spi.csn
        self.tx = spi.miso
        self.rx = spi.mosi

    async def send_data(self, data: int, len: int):
        """Queue a word to be shifted out on miso.

        Args:
            data: Word to transmit; bit ``len - 1`` is driven first.
            len: Number of bits to exchange for this word.
        """
        await self.transactions.put((data, len))

    async def received_data(self) -> int:
        """Return the next word collected from mosi, waiting if none is ready."""
        return await self.received.get()

    async def expect_transaction_in(self, max: int, unit: str = "ns"):
        """Announce that csn has to fall within ``max`` ``unit`` from wake-up.

        Args:
            max: Longest time the model waits for the falling edge of csn.
            unit: Time unit of ``max``.
        """
        self.data = (max, unit)
        self._wakeup.set()

    async def wait_one(self):
        """Wait until the model signals its transaction event the next time.

        The event is cleared first, so only a signal raised after this call
        wakes the caller.
        """
        self._transaction.clear()
        await self._transaction.wait()

    async def wait_all(self):
        """Wait until the model reports that its word queue is empty."""
        await self._idle.wait()

    async def _check_sck_period(self):
        """Watch sck forever and fail on pulses shorter than half a period.

        Raises:
            Exception: When two consecutive sck value changes are closer than
                ``config.sck_period / 2`` (with a 0.05 % tolerance).
        """
        await Edge(self.sck)
        while True:
            unit = self.config.sck_period_unit
            previous = get_sim_time(unit)
            await Edge(self.sck)
            elapsed = get_sim_time(unit) - previous
            half_period = self.config.sck_period / 2
            # Changes within the same time step (elapsed == 0) are ignored.
            if elapsed > 0 and elapsed * 1.0005 < half_period:
                raise Exception(f"The sck pulse is too narrow! (was: {elapsed} {unit}, expected: {half_period} {unit})")

    def _sck_timeout(self):
        """Return a fresh timer of two sck periods, used to bound edge waits."""
        return Timer(self.config.sck_period * 2, self.config.sck_period_unit)

    def _drive_msb(self, data: int, bit_count: int):
        """Force miso to bit ``bit_count - 1`` of ``data``."""
        self.tx.value = cocotb.handle.Force((data >> (bit_count - 1)) & 1)

    async def _await_csn_falling(self):
        """Wait for csn to fall within the announced deadline and check sck.

        Raises:
            Exception: When the deadline expires first, or when sck is not at
                ``config.clock_polarity`` once csn has fallen.
        """
        deadline, unit = self.data
        csn_falling = FallingEdge(self.csn)
        fired = await First(csn_falling, Timer(deadline, unit))
        if fired != csn_falling:
            message = f"CSN did not fall in time ({deadline} {unit})!"
            self._log.error(message)
            raise Exception(message)
        if int(self.sck.value) != self.config.clock_polarity:
            raise Exception("The clock is not at correct polarity after CSN falling!")

    async def _exchange_word(self, data: int, bit_count: int, first_word: bool) -> int:
        """Shift ``data`` out on miso while sampling ``bit_count`` bits of mosi.

        Args:
            data: Word to transmit, most significant bit first.
            bit_count: Number of sampling edges to wait for.
            first_word: True for the first word after csn fell; its top bit is
                driven immediately and shifting edges seen before the first
                sampling edge are ignored.

        Returns:
            The sampled bits, first sampled bit most significant. A mosi value
            that cannot be converted to int contributes a 0 bit.

        Raises:
            Exception: When csn rises during the word or no sck edge arrives
                within two sck periods.
        """
        if first_word:
            self._drive_msb(data, bit_count)
            data <<= 1
        received = 0
        got_sampling = False
        remaining = bit_count
        while remaining > 0:
            sampling = self.config.sampling(self.sck)
            shifting = self.config.shifting(self.sck)
            csn_rising = RisingEdge(self.csn)
            timeout = self._sck_timeout()
            fired = await First(sampling, shifting, csn_rising, timeout)
            if fired == csn_rising:
                raise Exception("CSN rising too soon!")
            if fired == timeout:
                self._log.error("Got no sck edge in time!")
                raise Exception("Got not sck edge in time")
            if fired == shifting:
                if got_sampling or not first_word:
                    self._drive_msb(data, bit_count)
                    data <<= 1
            elif fired == sampling:
                got_sampling = True
                received <<= 1
                try:  # if z or something, just leave it be
                    received |= int(self.rx.value)
                except Exception:
                    pass
                remaining -= 1
        return received

    async def _await_trailing_sck_edge(self):
        """Wait for one more shifting edge when sck is not at its idle level.

        Raises:
            Exception: When csn rises first or no edge arrives within two sck
                periods.
        """
        if int(self.sck.value) == self.config.clock_polarity:
            return
        # Build the trigger here instead of relying on a local left over from
        # the word loop.
        shifting = self.config.shifting(self.sck)
        csn_rising = RisingEdge(self.csn)
        timeout = self._sck_timeout()
        fired = await First(shifting, csn_rising, timeout)
        if fired == csn_rising:
            raise Exception("CSN rising too soon!")
        if fired == timeout:
            raise Exception("No sampling edge nor csn rising!")

    async def _await_csn_rising(self):
        """Wait for csn to rise and check that sck is back at its idle level.

        Raises:
            Exception: When a sampling edge of sck or a two-period timeout
                comes before the rising edge of csn, or when sck is not at
                ``config.clock_polarity`` afterwards.
        """
        unit = self.config.sck_period_unit
        started = get_sim_time(unit)
        # Repeat while simulation time has not advanced since entry.
        while started == get_sim_time(unit):
            timeout = self._sck_timeout()
            sck_edge = self.config.sampling(self.sck)
            csn_rising = RisingEdge(self.csn)
            fired = await First(csn_rising, sck_edge, timeout)
            if fired == timeout:
                self._log.error("Got no rising edge on csn")
                raise Exception("Got no rising edge on csn")
            if fired == sck_edge:
                self._log.error("Got sck edge when csn rising was expected")
                raise Exception("Got sck edge when csn rising was expected")
        self._log.info("csn is rising")
        if int(self.sck.value) != self.config.clock_polarity:
            raise Exception("The clock is not at correct polarity after CSN rising!")

    async def coroutine(self):
        """Run the slave model forever.

        Starts the sck period checker, then repeats: report idle when no word
        is queued, wait for expect_transaction_in(), wait for csn to fall,
        exchange the queued words (only one when ``config.csn_pulse`` is set),
        wait for csn to rise, release miso and signal the transaction event.

        Raises:
            Exception: On any protocol violation detected by the helpers.
        """
        await cocotb.start(self._check_sck_period())
        while True:
            if self.transactions.empty():
                self._idle.set()
            # TODO: also watch csn here and log an error when it falls while
            # no transaction is expected.
            await self._wakeup.wait()
            self._log.info("Got new transaction expectation.")
            self._wakeup.clear()
            self._idle.clear()

            await self._await_csn_falling()

            first_word = True
            while not self.transactions.empty():
                data, bit_count = await self.transactions.get()
                self._transaction.set()
                received = await self._exchange_word(data, bit_count, first_word)
                await Timer(1, "ns")
                await self.received.put(received)
                self._log.info(f"Received {received}")
                first_word = False
                if self.config.csn_pulse:
                    break

            await self._await_trailing_sck_edge()
            await self._await_csn_rising()
            self.tx.value = cocotb.handle.Release()
            self._transaction.set()