import abc
import dataclasses
import typing
from abc import abstractmethod
from ._util import Assert, Copy, Iterable, MechId, MechIdTuple, Merge
from .exceptions import ModelDefinitionError
if typing.TYPE_CHECKING:
from .parameter import Parameter
@dataclasses.dataclass
class CableProperties(Copy, Merge, Assert, Iterable):
    """
    Passive cable properties of a cable type. Both fields default to ``None``.
    """

    # Axial resistivity.
    # NOTE(review): previously documented as "ohm/cm"; resistivity is normally
    # expressed in ohm*cm — confirm the intended unit.
    Ra: float = None
    # Membrane capacitance (``CableType.default`` sets it to 1; the unit is not
    # stated in this module — confirm).
    cm: float = None
class CablePropertiesDict(typing.TypedDict, total=False):
    """Dictionary form of :class:`CableProperties`; every key is optional."""

    Ra: float
    cm: float
@dataclasses.dataclass
class Ion(Copy, Merge, Assert, Iterable):
    """
    Settings for a single ion species. Every field defaults to ``None``.
    """

    # Reversal potential (presumably mV — confirm).
    rev_pot: float = dataclasses.field(default=None)
    # Internal concentration (presumably mM — confirm).
    int_con: float = dataclasses.field(default=None)
    # External concentration (presumably mM — confirm).
    ext_con: float = dataclasses.field(default=None)
class IonDict(typing.TypedDict, total=False):
    """Dictionary form of :class:`Ion`; every key is optional."""

    rev_pot: float
    int_con: float
    ext_con: float
class Mechanism:
    """
    A mechanism described by a mapping of parameter names to values.
    """

    def __init__(self, parameters: dict[str, float]):
        super().__init__()
        self.parameters = parameters

    def merge(self, other):
        """
        Overwrite this mechanism's parameters, in place, with those of ``other``.

        Keys only present on ``self`` are kept; keys present on ``other`` win.
        """
        self.parameters.update(other.parameters)

    def copy(self):
        """
        Return a new instance of the same class with a shallow copy of the parameters.
        """
        # ``type(self)`` so copying a subclass keeps the subclass. Mechanism classes
        # are already constructed from just a parameter dict by ``_parse_mech_def``
        # (``cls.mechanism_class(mech_dict.copy())``), so subclasses are expected to
        # accept that signature; subclasses with another signature (e.g. Synapse)
        # override ``copy``.
        return type(self)(self.parameters.copy())
class Synapse(Mechanism):
    """
    A mechanism that also records the mechanism id it is an instance of.
    """

    mech_id: MechIdTuple

    def __init__(self, parameters, mech_id: MechId):
        super().__init__(parameters)
        # Stored in tuple form; ``to_mech_id`` rejects ``None`` with a ValueError.
        self.mech_id = to_mech_id(mech_id)

    def copy(self):
        """Return a same-class synapse with copied parameters and the same mech id."""
        synapse_cls = type(self)
        params = self.parameters.copy()
        return synapse_cls(params, to_mech_id(self.mech_id))
class ExpandedSynapseDict(typing.TypedDict, total=False):
    """Long-form synapse definition: an explicit mechanism id plus its parameters."""

    mechanism: MechId
    parameters: dict[str, float]


# A synapse definition is either a bare ``{parameter: value}`` dict or the
# expanded form above.
SynapseDict = typing.Union[dict[str, float], ExpandedSynapseDict]
def is_mech_id(mech_id):
    """
    Return whether ``mech_id`` is a valid mechanism id.

    A valid id is either a string, or a tuple of one to three strings. Any other
    value (lists, numbers, ``None``, ...) is reported as invalid rather than
    raising.
    """
    if isinstance(mech_id, str):
        return True
    return (
        isinstance(mech_id, tuple)
        and 0 < len(mech_id) < 4
        and all(isinstance(part, str) for part in mech_id)
    )
def to_mech_id(mech_id: MechId) -> MechIdTuple:
    """
    Normalize a mechanism id to tuple form.

    Tuples are returned as plain tuples; any other value is wrapped in a 1-tuple.

    :raises ValueError: if ``mech_id`` is ``None``.
    """
    if mech_id is None:
        raise ValueError("Mech id may not be None")
    if isinstance(mech_id, tuple):
        return tuple(mech_id)
    return (mech_id,)
class CableType:
    """
    Description of one kind of cable: its cable properties, ions, mechanisms and
    synapses. Cable types can be layered on top of each other with :meth:`merge`.
    """

    cable: CableProperties
    ions: dict[str, Ion]
    mechs: dict[MechId, Mechanism]
    synapses: dict[MechId, Synapse]

    def __init__(self, cable_property_class=CableProperties):
        self.cable = cable_property_class()
        self.ions = {}
        self.mechs = {}
        self.synapses = {}

    def copy(self):
        """
        Return a new cable type whose cable properties, ions, mechanisms and
        synapses are copies of this one's.
        """
        def_ = type(self)()
        def_.cable = self.cable.copy()
        def_.ions = {k: v.copy() for k, v in self.ions.items()}
        def_.mechs = {k: v.copy() for k, v in self.mechs.items()}
        def_.synapses = {k: v.copy() for k, v in self.synapses.items()}
        return def_

    def set(self, param: "Parameter"):
        """
        Apply ``param`` to this cable type.

        The parameter is handed the cable properties (``set_cable_params``) and/or
        the mechanisms dict (``set_mech_params``), for whichever hooks it defines.
        """
        if hasattr(param, "set_cable_params"):
            param.set_cable_params(self.cable)
        if hasattr(param, "set_mech_params"):
            param.set_mech_params(self.mechs)

    @classmethod
    def anchor(
        cls,
        defs: typing.Iterable["CableType"],
        synapses: dict[MechId, Synapse] = None,
        use_defaults: bool = False,
        ion_class=Ion,
    ) -> "CableType":
        """
        Combine several cable types into one new cable type.

        :param defs: Cable types merged in order, later items overriding earlier
            ones. ``None`` entries are skipped.
        :param synapses: Optional global synapse types forming the base layer of the
            result's synapses; synapses from ``defs`` are merged on top. The given
            synapse objects are copied, so the merges never modify them.
        :param use_defaults: Start from :meth:`default` instead of an empty cable type.
        :param ion_class: Ion class passed to :meth:`default`.
        :returns: The new, merged cable type.
        """
        def_ = cls() if not use_defaults else cls.default(ion_class)
        if synapses is not None:
            # Layer the local synapses over the global ones without touching the
            # caller's dict or the Synapse objects in it:
            # - collect *copies* of the global synapses in a scratch cable type
            #   (copies, because merging below modifies synapses in place),
            globaldef = cls()
            for key, value in synapses.items():
                globaldef.add_synapse(key, value.copy())
            # - merge any synapses already present on ``def_`` over them,
            globaldef._mergedict(globaldef.synapses, def_.synapses)
            # - and use the result as ``def_``'s synapses.
            def_.synapses = globaldef.synapses
        # Merge the definitions onto our def. Each merge overwrites our values, with
        # the last item in the list having the final say.
        for def_right in defs:
            if def_right is None:
                continue
            def_.merge(def_right)
        return def_

    def merge(self, def_right: "CableType"):
        """
        Merge ``def_right`` into this cable type in place; its values take priority.
        """
        self.cable.merge(def_right.cable)
        self._mergedict(self.ions, def_right.ions)
        self._mergedict(self.mechs, def_right.mechs)
        self._mergedict(self.synapses, def_right.synapses)

    def _mergedict(self, dself, dother):
        # Entries present in both are merged into ``dself``'s object; new entries
        # are copied so ``dself`` never shares objects with ``dother``.
        for key, value in dother.items():
            if key in dself:
                dself[key].merge(value)
            else:
                dself[key] = value.copy()

    def assert_(self):
        """
        Validate the cable properties and every ion.

        :raises ValueError: when an ion fails its check. The message names the ion
            and the missing field, taken from ``args[1]`` of the ValueError raised
            by ``ion.assert_()``.
        """
        self.cable.assert_()
        for ion_name, ion in self.ions.items():
            try:
                ion.assert_()
            except ValueError as e:
                raise ValueError(
                    f"Missing '{e.args[1]}' value in ion '{ion_name}'",
                    ion_name,
                    e.args[1],
                ) from None

    @classmethod
    def default(cls, ion_class=Ion):
        """
        Return a cable type with ``Ra=35.4``, ``cm=1`` and an ions dict that fills
        in built-in default values for well-known ions as they are added.
        """
        default = cls()
        default.cable.Ra = 35.4
        default.cable.cm = 1
        default.ions = default_ions_dict(ion_class)
        return default

    def add_ion(self, key: str, ion: Ion):
        """
        Register ``ion`` under ``key``.

        :raises KeyError: if an ion with that name already exists.
        """
        if key in self.ions:
            raise KeyError(f"An ion named '{key}' already exists.")
        self.ions[key] = ion

    def add_mech(self, mech_id: MechId, mech: Mechanism):
        """
        Register ``mech`` under ``mech_id``.

        :raises ValueError: if ``mech_id`` is not a valid mechanism id.
        :raises KeyError: if a mechanism with that id already exists.
        """
        if not is_mech_id(mech_id):
            raise ValueError(f"'{mech_id}' is not a valid mechanism id.")
        if mech_id in self.mechs:
            raise KeyError(f"A mechanism with id '{mech_id}' already exists.")
        self.mechs[mech_id] = mech

    def add_synapse(self, label: typing.Union[str, MechId], synapse: Synapse):
        """
        Register ``synapse`` under ``label``.

        The synapse's own ``mech_id`` (or, if falsy, ``label``) must be a valid
        mechanism id.

        :raises ValueError: if the mechanism id is invalid.
        :raises KeyError: if a synapse with that label already exists.
        """
        mech_id = synapse.mech_id or to_mech_id(label)
        if not is_mech_id(mech_id):
            raise ValueError(f"'{mech_id}' is not a valid mechanism id.")
        if label in self.synapses:
            raise KeyError(f"A synapse with label '{label}' already exists.")
        self.synapses[label] = synapse
class CableTypeDict(typing.TypedDict, total=False):
    """Dictionary form of a cable type; every section is optional."""

    cable: CablePropertiesDict
    ions: dict[str, IonDict]
    mechanisms: dict[MechId, dict[str, float]]
    synapses: dict[MechId, SynapseDict]
class default_ions_dict(dict):
    """
    Ion dictionary that completes newly assigned ions with built-in defaults.

    When a key that is not yet present is assigned, and built-in defaults exist
    for it (``na``, ``k``, ``ca``, ``h``), the defaults are merged into the
    assigned ion object, which is then stored. Keys without built-in defaults are
    stored unchanged. Items passed to the constructor bypass this logic.
    """

    def __init__(self, ion_class, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._ion_class = ion_class

    def _make_defaults(self):
        # Built lazily, on first assignment, using the configured ion class.
        self._defaults = {
            "na": self._ion_class(rev_pot=50.0, int_con=10.0, ext_con=140.0),
            "k": self._ion_class(rev_pot=-77.0, int_con=54.4, ext_con=2.5),
            "ca": self._ion_class(
                rev_pot=132.4579341637009, int_con=5e-05, ext_con=2.0
            ),
            "h": self._ion_class(rev_pot=0.0, int_con=1.0, ext_con=1.0),
        }

    def __setitem__(self, key, ion):
        if key not in self:
            if not hasattr(self, "_defaults"):
                # Guards against ``_ion_class`` not being set yet (i.e. when
                # ``__init__`` was bypassed); fall back to the assigned ion's class.
                if not hasattr(self, "_ion_class"):
                    self._ion_class = type(ion)
                self._make_defaults()
            default = self._defaults.get(key)
            # Ions without built-in defaults (e.g. "cl") are stored as given.
            if default is not None:
                value = default.copy()
                # Do a criss-cross merge to merge defaults into the original ion
                # object (values set on ``ion`` are meant to win — relies on the
                # ion class's ``merge`` semantics).
                value.merge(ion)
                ion.merge(value)
        super().__setitem__(key, ion)
# Type variables for the concrete classes a ``Definition`` subclass plugs in,
# in the order used by ``Definition``'s generic parameters: cable type, cable
# properties, ion, mechanism and synapse.
CT = typing.TypeVar("CT", bound=CableType)
CP = typing.TypeVar("CP", bound=CableProperties)
I = typing.TypeVar("I", bound=Ion)
M = typing.TypeVar("M", bound=Mechanism)
S = typing.TypeVar("S", bound=Synapse)
class Definition(typing.Generic[CT, CP, I, M, S], abc.ABC):
    """
    Abstract model definition: a registry of labelled cable types and synapse types.

    Concrete subclasses declare, through the five ``*_class`` class-level
    attributes, which classes are used to build each part of the model.
    """

    # NOTE(review): chaining ``@classmethod`` over ``@property`` was deprecated in
    # Python 3.11 and no longer works in 3.13 (attribute access stops yielding the
    # property's value). Subclasses may implement these as plain class attributes,
    # which are read the same way (``cls.cable_type_class``) on every version.
    @classmethod
    @property
    @abstractmethod
    def cable_type_class(cls) -> typing.Type[CT]:
        pass

    @classmethod
    @property
    @abstractmethod
    def cable_properties_class(cls) -> typing.Type[CP]:
        pass

    @classmethod
    @property
    @abstractmethod
    def ion_class(cls) -> typing.Type[I]:
        pass

    @classmethod
    @property
    @abstractmethod
    def mechanism_class(cls) -> typing.Type[M]:
        pass

    @classmethod
    @property
    @abstractmethod
    def synapse_class(cls) -> typing.Type[S]:
        pass

    def __init__(self, use_defaults=False):
        self._cable_types: dict[str, CT] = {}
        self._synapse_types: dict[MechId, S] = {}
        self.use_defaults = use_defaults

    def copy(self):
        """
        Return a new definition of the same class holding copies of every cable
        type and synapse type, so it can be modified without affecting ``self``.
        """
        model = type(self)(self.use_defaults)
        for label, def_ in self._cable_types.items():
            model.add_cable_type(label, def_.copy())
        for label, synapse in self._synapse_types.items():
            # Copied as well: merging into the copy must not alter our synapses.
            model.add_synapse_type(label, synapse.copy())
        return model

    def merge(self, other: "Definition"):
        """
        Merge ``other`` into this definition in place.

        Cable types and synapse types whose label exists in both are merged, with
        ``other``'s values taking priority; labels only in ``other`` are added as
        copies, so ``self`` never shares objects with ``other``.
        """
        for label, def_ in other._cable_types.items():
            if label in self._cable_types:
                self._cable_types[label].merge(def_)
            else:
                self._cable_types[label] = def_.copy()
        for label, synapse in other._synapse_types.items():
            if label in self._synapse_types:
                self._synapse_types[label].merge(synapse)
            else:
                self._synapse_types[label] = synapse.copy()

    def get_cable_types(self) -> dict[str, CT]:
        """Return copies of all cable types, keyed by label."""
        return {k: v.copy() for k, v in self._cable_types.items()}

    def get_synapse_types(self) -> dict[str, S]:
        """Return copies of all synapse types, keyed by label."""
        return {k: v.copy() for k, v in self._synapse_types.items()}

    def add_cable_type(self, label: str, def_: CT):
        """
        Register ``def_`` under ``label``.

        :raises KeyError: if the label is already taken.
        """
        if label in self._cable_types:
            raise KeyError(f"Cable type {label} already exists.")
        self._cable_types[label] = def_

    def add_synapse_type(self, label: typing.Union[str, MechId], synapse: S):
        """
        Register ``synapse`` under ``label``.

        The synapse's own ``mech_id`` (or, if falsy, ``label``) must be a valid
        mechanism id.

        :raises ValueError: if the mechanism id is invalid.
        :raises KeyError: if the label is already taken.
        """
        mech_id = synapse.mech_id or to_mech_id(label)
        if not is_mech_id(mech_id):
            raise ValueError(f"'{mech_id}' is not a valid synapse mechanism.")
        if label in self._synapse_types:
            raise KeyError(f"Synapse type {label} already exists.")
        self._synapse_types[label] = synapse
class ModelDefinition(Definition[CableType, CableProperties, Ion, Mechanism, Synapse]):
    """
    Model definition built from the base :class:`CableType`,
    :class:`CableProperties`, :class:`Ion`, :class:`Mechanism` and
    :class:`Synapse` classes.
    """

    # Plain class attributes rather than ``@classmethod @property``: that decorator
    # chain was deprecated in Python 3.11 and no longer works in 3.13, where e.g.
    # ``cls.cable_type_class(...)`` in ``_parse_cable_type`` would fail. Class
    # attributes are accessed identically (``cls.x`` / ``self.x``) on every
    # version, and they satisfy the abstract declarations on ``Definition``.
    cable_type_class = CableType
    cable_properties_class = CableProperties
    ion_class = Ion
    mechanism_class = Mechanism
    synapse_class = Synapse
class ModelDefinitionDict(typing.TypedDict, total=False):
    """Dictionary form of a model definition; both sections are optional."""

    cable_types: dict[str, CableTypeDict]
    synapse_types: dict[MechId, SynapseDict]
@typing.overload
def define_model(
    template: ModelDefinition,
    definition: ModelDefinitionDict,
    /,
    use_defaults: bool = ...,
) -> ModelDefinition: ...


@typing.overload
def define_model(
    definition: ModelDefinitionDict, /, use_defaults: bool = ...
) -> ModelDefinition: ...


def define_model(templ_or_def, def_dict=None, /, use_defaults=False) -> ModelDefinition:
    """
    Build a :class:`ModelDefinition` from a dictionary.

    Called with one positional argument, that argument is parsed as the model
    definition dictionary. Called with two, the first is a template definition:
    it is copied, and the parsed second argument is merged into the copy with the
    copy's ``merge`` method, so the template itself is left untouched.

    In both cases ``use_defaults`` is stored on the returned model.
    """
    if def_dict is not None:
        model = templ_or_def.copy()
        model.merge(_parse_dict_def(ModelDefinition, def_dict))
    else:
        model = _parse_dict_def(ModelDefinition, templ_or_def)
    model.use_defaults = use_defaults
    return model
D = typing.TypeVar("D", bound=Definition)


def _parse_dict_def(cls: typing.Type[D], def_dict: ModelDefinitionDict) -> D:
    """
    Create a fresh ``cls`` instance and populate it from ``def_dict``.

    Entries under ``"cable_types"`` are parsed and registered first, then those
    under ``"synapse_types"``; either section may be omitted.
    """
    model = cls()
    cable_section = def_dict.get("cable_types", {})
    for name, cable_input in cable_section.items():
        model.add_cable_type(name, _parse_cable_type(cls, cable_input))
    synapse_section = def_dict.get("synapse_types", {})
    for name, synapse_input in synapse_section.items():
        model.add_synapse_type(name, _parse_synapse_def(cls, name, synapse_input))
    return model
def _parse_cable_type(cls: typing.Type[Definition], cable_dict: CableTypeDict):
    """
    Build a ``cls.cable_type_class`` instance from a cable type dictionary.

    The optional ``"cable"``, ``"ions"``, ``"mechanisms"`` and ``"synapses"``
    sections are parsed with the classes declared on ``cls``.

    :raises ModelDefinitionError: if any part fails to parse or register; the
        underlying error is chained as ``__cause__`` so its detail is not lost.
    """
    try:
        def_ = cls.cable_type_class(cls.cable_properties_class)
        def_.cable = cls.cable_properties_class(**cable_dict.get("cable", {}))
        for ion_name, ion_dict in cable_dict.get("ions", {}).items():
            def_.add_ion(ion_name, _parse_ion_def(cls, ion_dict))
        for mech_id, mech_dict in cable_dict.get("mechanisms", {}).items():
            def_.add_mech(mech_id, _parse_mech_def(cls, mech_dict))
        for label, synapse_dict in cable_dict.get("synapses", {}).items():
            def_.add_synapse(label, _parse_synapse_def(cls, label, synapse_dict))
        return def_
    except Exception as e:
        raise ModelDefinitionError(
            f"{cable_dict} is not a valid cable type definition."
        ) from e
def _parse_ion_def(cls: typing.Type[Definition], ion_dict: IonDict):
    """
    Build a ``cls.ion_class`` instance from an ion dictionary (keyword arguments).

    :raises ModelDefinitionError: if construction fails; the original error is
        chained as ``__cause__``.
    """
    try:
        return cls.ion_class(**ion_dict)
    except Exception as e:
        raise ModelDefinitionError(f"{ion_dict} is not a valid ion definition.") from e
def _parse_mech_def(cls: typing.Type[Definition], mech_dict: dict[str, float]):
    """
    Build a ``cls.mechanism_class`` instance from a ``{parameter: value}`` dict.

    The dict is shallow-copied so the mechanism does not share it with the caller.

    :raises ModelDefinitionError: if construction fails; the original error is
        chained as ``__cause__``.
    """
    try:
        return cls.mechanism_class(mech_dict.copy())
    except Exception as e:
        raise ModelDefinitionError(
            f"{mech_dict} is not a valid mechanism definition."
        ) from e
def _parse_synapse_def(cls: typing.Type[Definition], key, synapse_dict: SynapseDict):
    """
    Build a ``cls.synapse_class`` instance from a synapse dictionary.

    Accepted forms:

    * expanded: ``{"mechanism": mech_id, "parameters": {...}}``; ``parameters``
      may be omitted, meaning no parameters;
    * short: ``{"parameters": {...}}`` or a bare ``{parameter: value}`` dict, in
      which case ``key`` is used as the mechanism id.

    The parameter dict is shallow-copied.

    :raises ModelDefinitionError: if the dictionary cannot be turned into a
        synapse; the original error is chained as ``__cause__``.
    """
    try:
        if "mechanism" in synapse_dict:
            # Expanded form: explicit mechanism id, parameters optional.
            parameters = synapse_dict.get("parameters", {})
            mech_id = synapse_dict["mechanism"]
        else:
            # Short form: unless `parameters` is given, every item is a parameter,
            # and the key doubles as the mechanism id.
            parameters = synapse_dict.get("parameters", synapse_dict)
            mech_id = key
        return cls.synapse_class(parameters.copy(), mech_id)
    except Exception as e:
        raise ModelDefinitionError(
            f"{synapse_dict} is not a valid synapse definition."
        ) from e
class mechdict(dict):
    """
    Dict keyed by mechanism ids in tuple form.

    A plain ``str`` key is stored and looked up as the 1-tuple ``(key,)``; any
    other key is used as given. The normalization is applied consistently to item
    access, assignment, deletion, membership tests, ``get``, ``pop``,
    ``setdefault``, ``update`` and the constructor.
    """

    @staticmethod
    def _normalize(key):
        # Strings become 1-tuples; everything else is left unchanged.
        return (key,) if isinstance(key, str) else key

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Route initial items through __setitem__ so their keys are normalized too.
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def __getitem__(self, item):
        return super().__getitem__(self._normalize(item))

    def __setitem__(self, key, value):
        super().__setitem__(self._normalize(key), value)

    def __delitem__(self, key):
        super().__delitem__(self._normalize(key))

    def __contains__(self, key):
        return super().__contains__(self._normalize(key))

    def get(self, key, default=None):
        return super().get(self._normalize(key), default)

    def pop(self, key, *default):
        return super().pop(self._normalize(key), *default)

    def setdefault(self, key, default=None):
        return super().setdefault(self._normalize(key), default)