Source code for arborize.definitions

import abc
import dataclasses
import typing
from abc import abstractmethod

from ._util import Assert, Copy, Iterable, MechId, MechIdTuple, Merge
from .exceptions import ModelDefinitionError

if typing.TYPE_CHECKING:
    from .parameter import Parameter


@dataclasses.dataclass
class CableProperties(Copy, Merge, Assert, Iterable):
    """
    Cable-level properties (``Ra`` and ``cm``) of a cable type.

    Both fields default to ``None``, i.e. "not specified". Copying, merging,
    validation and iteration behaviour come from the mixin base classes.
    """

    Ra: float = None
    """
    Axial resistivity in ohm/cm
    """
    # NOTE(review): NEURON documents ``Ra`` in ohm*cm rather than ohm/cm;
    # confirm which unit is intended before relying on the text above.
    cm: float = None
    # Presumably the specific membrane capacitance (NEURON's ``cm``) -- TODO
    # confirm; the original source attached the ``Ra`` description to this
    # field instead.
class CablePropertiesDict(typing.TypedDict, total=False):
    """Dictionary form of the cable properties; every key is optional."""

    Ra: float
    cm: float
@dataclasses.dataclass
class Ion(Copy, Merge, Assert, Iterable):
    """
    Per-ion settings of a cable type.

    All fields default to ``None``, i.e. "not specified". Copying, merging,
    validation and iteration behaviour come from the mixin base classes.
    """

    # Presumably the reversal potential -- TODO confirm meaning and unit.
    rev_pot: float = None
    # Presumably the internal concentration -- TODO confirm meaning and unit.
    int_con: float = None
    # Presumably the external concentration -- TODO confirm meaning and unit.
    ext_con: float = None
class IonDict(typing.TypedDict, total=False):
    """Dictionary form of an ion definition; every key is optional."""

    rev_pot: float
    int_con: float
    ext_con: float
class Mechanism:
    """
    A set of named parameter values for one mechanism.

    :param parameters: Mapping of parameter name to value. The dictionary is
      stored as given (not copied), so later changes to it are visible here.
    """

    def __init__(self, parameters: dict[str, float]):
        super().__init__()
        self.parameters = parameters

    def merge(self, other):
        """
        Overwrite this mechanism's parameters with those of ``other``.

        Parameters that only exist on this mechanism are kept.

        :param other: Object with a ``parameters`` mapping to merge in.
        """
        self.parameters.update(other.parameters)

    def copy(self):
        """
        Return a new mechanism holding a shallow copy of the parameters.

        The copy is created with ``type(self)`` so that subclasses that keep
        the ``(parameters)`` constructor signature are copied as themselves
        rather than being downgraded to a plain ``Mechanism``.
        """
        return type(self)(self.parameters.copy())
class Synapse(Mechanism):
    """
    Mechanism parameters together with the id of the mechanism to use.

    :param parameters: Mapping of parameter name to value.
    :param mech_id: Mechanism id; converted with :func:`to_mech_id` before
      it is stored on ``mech_id``.
    """

    mech_id: MechIdTuple

    def __init__(self, parameters, mech_id: MechId):
        super().__init__(parameters)
        self.mech_id = to_mech_id(mech_id)

    def copy(self):
        """Return a new synapse of the same class with copied parameters."""
        synapse_class = type(self)
        parameters = self.parameters.copy()
        mech_id = to_mech_id(self.mech_id)
        return synapse_class(parameters, mech_id)
class ExpandedSynapseDict(typing.TypedDict, total=False):
    """Expanded synapse form: an explicit mechanism id plus its parameters."""

    mechanism: MechId
    parameters: dict[str, float]


# A synapse is given either in short form (a flat parameter dictionary) or in
# the expanded form above.
SynapseDict = typing.Union[dict[str, float], ExpandedSynapseDict]
def is_mech_id(mech_id):
    """
    Check whether a value is a valid mechanism id.

    A valid id is either a string, or a tuple of one to three strings.

    Any other value, including non-iterables such as ``None`` or numbers,
    yields ``False`` instead of raising, so callers can turn an invalid id
    into their own error.

    :param mech_id: Value to check.
    :returns: ``True`` when the value is a valid mechanism id.
    :rtype: bool
    """
    if isinstance(mech_id, str):
        return True
    return (
        isinstance(mech_id, tuple)
        and 0 < len(mech_id) < 4
        and all(isinstance(part, str) for part in mech_id)
    )
def to_mech_id(mech_id: MechId) -> MechIdTuple:
    """
    Convert a mechanism id to its tuple form.

    :param mech_id: A tuple, which is returned as a tuple, or any other
      value, which is wrapped in a one-element tuple.
    :returns: The id as a tuple.
    :raises ValueError: If ``mech_id`` is ``None``.
    """
    if mech_id is None:
        raise ValueError("Mech id may not be None")
    if isinstance(mech_id, tuple):
        return tuple(mech_id)
    return (mech_id,)
class CableType:
    """
    Definition of one type of cable: its cable properties, ions, mechanisms
    and synapses.

    :param cable_property_class: Class instantiated without arguments to
      create the initial ``cable`` attribute.
    """

    cable: CableProperties
    ions: dict[str, Ion]
    mechs: dict[MechId, Mechanism]
    synapses: dict[MechId, Synapse]

    def __init__(self, cable_property_class=CableProperties):
        self.cable = cable_property_class()
        self.ions = {}
        self.mechs = {}
        self.synapses = {}

    def copy(self):
        """
        Return a new cable type of the same class whose cable properties,
        ions, mechanisms and synapses are copies of this one's.

        The copied ``ions``, ``mechs`` and ``synapses`` are plain dictionaries.
        """
        def_ = type(self)()
        def_.cable = self.cable.copy()
        def_.ions = {k: v.copy() for k, v in self.ions.items()}
        def_.mechs = {k: v.copy() for k, v in self.mechs.items()}
        def_.synapses = {k: v.copy() for k, v in self.synapses.items()}
        return def_

    def set(self, param: "Parameter"):
        """
        Let a parameter apply itself to this cable type.

        If ``param`` has a ``set_cable_params`` attribute it is called with
        ``self.cable``; if it has a ``set_mech_params`` attribute it is called
        with ``self.mechs``.
        """
        if hasattr(param, "set_cable_params"):
            param.set_cable_params(self.cable)
        if hasattr(param, "set_mech_params"):
            param.set_mech_params(self.mechs)

    @classmethod
    def anchor(
        cls,
        defs: typing.Iterable["CableType"],
        synapses: typing.Optional[dict[MechId, Synapse]] = None,
        use_defaults: bool = False,
        ion_class=Ion,
    ) -> "CableType":
        """
        Build one cable type by merging a sequence of definitions.

        :param defs: Definitions to merge, in order. ``None`` entries are
          skipped. Later entries overwrite values of earlier ones.
        :param synapses: Optional global synapses that act as the base the
          definitions' own synapses are merged onto.
        :param use_defaults: Start from :meth:`default` instead of an empty
          cable type.
        :param ion_class: Ion class handed to :meth:`default`.
        :returns: The merged cable type.
        """
        def_ = cls.default(ion_class) if use_defaults else cls()
        if synapses is not None:
            # We need to merge the local synapses on top of the global ones,
            # without mutating the global dictionary. So we:
            # - Create a new cable type for the global synapses
            globaldef = cls()
            # - Add *copies* of the synapses to it. The merges below update
            #   existing entries in place, so registering the caller's own
            #   objects would leak local overrides into the global synapses.
            for key, value in synapses.items():
                globaldef.add_synapse(key, value.copy())
            # - Merge the local synapses over it
            globaldef._mergedict(globaldef.synapses, def_.synapses)
            # - Transfer the result to our def.
            def_.synapses = globaldef.synapses
        # Merge the definitions onto our def. Each merge overwrites our values,
        # with the last item in the list having the final say.
        for def_right in defs:
            if def_right is None:
                continue
            def_.merge(def_right)
        return def_

    def merge(self, def_right: "CableType"):
        """
        Merge ``def_right`` into this cable type in place; its values take
        precedence over the ones already present.
        """
        self.cable.merge(def_right.cable)
        self._mergedict(self.ions, def_right.ions)
        self._mergedict(self.mechs, def_right.mechs)
        self._mergedict(self.synapses, def_right.synapses)

    def _mergedict(self, dself, dother):
        # Existing keys are merged in place; new keys receive a copy so that
        # `dother`'s objects are never shared with `dself`.
        for key, value in dother.items():
            if key in dself:
                dself[key].merge(value)
            else:
                dself[key] = value.copy()

    def assert_(self):
        """
        Validate the cable properties and every ion.

        :raises ValueError: If validation fails. For ions, the error is
          re-raised with the ion name added; this expects the original
          error's ``args[1]`` to hold the name of the missing value.
        """
        self.cable.assert_()
        for ion_name, ion in self.ions.items():
            try:
                ion.assert_()
            except ValueError as e:
                raise ValueError(
                    f"Missing '{e.args[1]}' value in ion '{ion_name}'",
                    ion_name,
                    e.args[1],
                ) from None

    @classmethod
    def default(cls, ion_class=Ion):
        """
        Return a cable type with ``Ra = 35.4``, ``cm = 1`` and an ion
        dictionary that fills in default ion values.

        :param ion_class: Class used to create the default ions.
        """
        default = cls()
        default.cable.Ra = 35.4
        default.cable.cm = 1
        default.ions = default_ions_dict(ion_class)
        return default

    def add_ion(self, key: str, ion: Ion):
        """
        Register an ion under ``key``.

        :raises KeyError: If an ion with that name already exists.
        """
        if key in self.ions:
            raise KeyError(f"An ion named '{key}' already exists.")
        self.ions[key] = ion

    def add_mech(self, mech_id: MechId, mech: Mechanism):
        """
        Register a mechanism under ``mech_id``.

        :raises ValueError: If ``mech_id`` is not a valid mechanism id.
        :raises KeyError: If a mechanism with that id already exists.
        """
        if not is_mech_id(mech_id):
            raise ValueError(f"'{mech_id}' is not a valid mechanism id.")
        if mech_id in self.mechs:
            raise KeyError(f"A mechanism with id '{mech_id}' already exists.")
        self.mechs[mech_id] = mech

    def add_synapse(self, label: typing.Union[str, MechId], synapse: Synapse):
        """
        Register a synapse under ``label``.

        The mechanism id that is validated is the synapse's own ``mech_id``,
        falling back to the label when that is empty.

        :raises ValueError: If the mechanism id is not valid.
        :raises KeyError: If a synapse with that label already exists.
        """
        mech_id = synapse.mech_id or to_mech_id(label)
        if not is_mech_id(mech_id):
            raise ValueError(f"'{mech_id}' is not a valid mechanism id.")
        if label in self.synapses:
            raise KeyError(f"A synapse with label '{label}' already exists.")
        self.synapses[label] = synapse
class CableTypeDict(typing.TypedDict, total=False):
    """Dictionary form of a cable type definition; every key is optional."""

    cable: CablePropertiesDict
    ions: dict[str, IonDict]
    mechanisms: dict[MechId, dict[str, float]]
    synapses: dict[MechId, SynapseDict]
class default_ions_dict(dict):
    """
    Dictionary of ions that completes newly added ions with default values.

    When an ion is stored under a key that is not present yet and a default
    exists for that key (``na``, ``k``, ``ca``, ``h``), the defaults are merged
    into the given ion object before it is stored. Ions without a built-in
    default are stored as given.

    :param ion_class: Class used to create the default ions; it is called with
      ``rev_pot``, ``int_con`` and ``ext_con`` keyword arguments.
    """

    def __init__(self, ion_class, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._ion_class = ion_class

    def _make_defaults(self):
        # Built lazily, on the first insertion that needs them.
        self._defaults = {
            "na": self._ion_class(rev_pot=50.0, int_con=10.0, ext_con=140.0),
            "k": self._ion_class(rev_pot=-77.0, int_con=54.4, ext_con=2.5),
            "ca": self._ion_class(
                rev_pot=132.4579341637009, int_con=5e-05, ext_con=2.0
            ),
            "h": self._ion_class(rev_pot=0.0, int_con=1.0, ext_con=1.0),
        }

    def __setitem__(self, key, ion):
        if key not in self:
            if not hasattr(self, "_defaults"):
                # `_ion_class` can be missing when items are set before
                # `__init__` has run (presumably during copying/unpickling --
                # TODO confirm); fall back to the class of the given ion.
                if not hasattr(self, "_ion_class"):
                    self._ion_class = type(ion)
                self._make_defaults()
            default = self._defaults.get(key)
            # Ions without a built-in default used to fail with a bare
            # KeyError here; they are now simply stored as given.
            if default is not None:
                value = default.copy()
                # Do a criss-cross merge to merge defaults into the original
                # ion object
                value.merge(ion)
                ion.merge(value)
        super().__setitem__(key, ion)
# Type variables used to parametrize `Definition` with the concrete classes
# that a model definition works with.
CT = typing.TypeVar("CT", bound=CableType)  # cable type class
CP = typing.TypeVar("CP", bound=CableProperties)  # cable properties class
I = typing.TypeVar("I", bound=Ion)  # ion class
M = typing.TypeVar("M", bound=Mechanism)  # mechanism class
S = typing.TypeVar("S", bound=Synapse)  # synapse class
class Definition(typing.Generic[CT, CP, I, M, S], abc.ABC):
    """
    Abstract base of model definitions: a labelled collection of cable types
    and synapse types.

    Subclasses must provide the five ``*_class`` attributes, which name the
    classes used for cable types, cable properties, ions, mechanisms and
    synapses.

    :param use_defaults: Stored on ``use_defaults``; passed on to copies.
    """

    # NOTE(review): stacking ``classmethod`` on ``property`` was deprecated in
    # Python 3.11 and removed in 3.13. The stacks below still mark the names
    # as abstract, but concrete subclasses should provide plain class
    # attributes rather than repeating the decorator stack.
    @classmethod
    @property
    @abstractmethod
    def cable_type_class(cls) -> typing.Type[CT]:
        pass

    @classmethod
    @property
    @abstractmethod
    def cable_properties_class(cls) -> typing.Type[CP]:
        pass

    @classmethod
    @property
    @abstractmethod
    def ion_class(cls) -> typing.Type[I]:
        pass

    @classmethod
    @property
    @abstractmethod
    def mechanism_class(cls) -> typing.Type[M]:
        pass

    @classmethod
    @property
    @abstractmethod
    def synapse_class(cls) -> typing.Type[S]:
        pass

    def __init__(self, use_defaults=False):
        self._cable_types: dict[str, CT] = {}
        self._synapse_types: dict[MechId, S] = {}
        self.use_defaults = use_defaults

    def copy(self):
        """
        Return a new definition of the same class holding copies of all cable
        types and synapse types, so that changing the copy never affects this
        definition.
        """
        model = type(self)(self.use_defaults)
        for label, def_ in self._cable_types.items():
            model.add_cable_type(label, def_.copy())
        for label, def_ in self._synapse_types.items():
            model.add_synapse_type(label, def_.copy())
        return model

    def merge(self, other: "Definition"):
        """
        Merge another definition into this one, in place.

        Cable types and synapse types whose label already exists are updated
        through their own ``merge`` method; those with a new label are added
        as copies. ``other`` is not modified.

        :param other: Definition whose values take precedence.
        """
        for label, cable_type in other._cable_types.items():
            if label in self._cable_types:
                self._cable_types[label].merge(cable_type)
            else:
                self.add_cable_type(label, cable_type.copy())
        for label, synapse in other._synapse_types.items():
            if label in self._synapse_types:
                self._synapse_types[label].merge(synapse)
            else:
                self.add_synapse_type(label, synapse.copy())

    def get_cable_types(self) -> dict[str, CT]:
        """Return copies of all cable types, keyed by label."""
        return {k: v.copy() for k, v in self._cable_types.items()}

    def get_synapse_types(self) -> dict[MechId, S]:
        """Return copies of all synapse types, keyed by label."""
        return {k: v.copy() for k, v in self._synapse_types.items()}

    def add_cable_type(self, label: str, def_: CT):
        """
        Register a cable type under ``label``.

        :raises KeyError: If the label is already in use.
        """
        if label in self._cable_types:
            raise KeyError(f"Cable type {label} already exists.")
        self._cable_types[label] = def_

    def add_synapse_type(self, label: typing.Union[str, MechId], synapse: S):
        """
        Register a synapse type under ``label``.

        The mechanism id that is validated is the synapse's own ``mech_id``,
        falling back to the label when that is empty.

        :raises ValueError: If the mechanism id is not valid.
        :raises KeyError: If the label is already in use.
        """
        mech_id = synapse.mech_id or to_mech_id(label)
        if not is_mech_id(mech_id):
            raise ValueError(f"'{mech_id}' is not a valid synapse mechanism.")
        if label in self._synapse_types:
            raise KeyError(f"Synapse type {label} already exists.")
        self._synapse_types[label] = synapse
class ModelDefinition(Definition[CableType, CableProperties, Ion, Mechanism, Synapse]):
    """
    Concrete definition that uses the classes of this module.

    The classes are exposed as plain class attributes. Reading them from the
    class or from an instance yields the class itself, exactly like the
    ``classmethod``/``property`` stack did, but this form keeps working on
    Python 3.13+, where ``classmethod`` can no longer wrap ``property``.
    """

    cable_type_class = CableType
    cable_properties_class = CableProperties
    ion_class = Ion
    mechanism_class = Mechanism
    synapse_class = Synapse
class ModelDefinitionDict(typing.TypedDict, total=False):
    """Dictionary form of a model definition; every key is optional."""

    cable_types: dict[str, CableTypeDict]
    synapse_types: dict[MechId, SynapseDict]


# Call signature 1: extend an existing definition with a dictionary.
@typing.overload
def define_model(
    template: ModelDefinition,
    definition: ModelDefinitionDict,
    /,
    use_defaults: bool = ...,
) -> ModelDefinition: ...


# Call signature 2: create a definition from a dictionary alone.
@typing.overload
def define_model(
    definition: ModelDefinitionDict,
    /,
    use_defaults: bool = ...,
) -> ModelDefinition: ...
def define_model(templ_or_def, def_dict=None, /, use_defaults=False) -> ModelDefinition:
    """
    Create a model definition from a dictionary, optionally on top of a
    template definition.

    :param templ_or_def: The definition dictionary when it is the only
      positional argument, otherwise the template definition to extend.
    :param def_dict: The definition dictionary when a template is given.
    :param use_defaults: Value assigned to the result's ``use_defaults``.
    :returns: The new model definition; a given template is copied first.
    """
    if def_dict is not None:
        # Template form: work on a copy and merge the parsed dictionary onto it.
        model = templ_or_def.copy()
        overlay = _parse_dict_def(ModelDefinition, def_dict)
        model.merge(overlay)
    else:
        # Dictionary-only form: the first argument is the definition itself.
        model = _parse_dict_def(ModelDefinition, templ_or_def)
    model.use_defaults = use_defaults
    return model
D = typing.TypeVar("D", bound=Definition)


def _parse_dict_def(cls: typing.Type[D], def_dict: ModelDefinitionDict) -> D:
    """
    Build a definition of class ``cls`` from its dictionary form.

    Reads the optional ``cable_types`` and ``synapse_types`` keys and registers
    each parsed entry under its label.
    """
    model = cls()
    for label, def_input in def_dict.get("cable_types", {}).items():
        model.add_cable_type(label, _parse_cable_type(cls, def_input))
    for label, def_input in def_dict.get("synapse_types", {}).items():
        model.add_synapse_type(label, _parse_synapse_def(cls, label, def_input))
    return model


def _parse_cable_type(cls: typing.Type[Definition], cable_dict: CableTypeDict):
    """
    Build a cable type from its dictionary form, using the classes of ``cls``.

    :raises ModelDefinitionError: If anything goes wrong while parsing; the
      underlying error is attached as the cause.
    """
    try:
        def_ = cls.cable_type_class(cls.cable_properties_class)
        def_.cable = cls.cable_properties_class(**cable_dict.get("cable", {}))
        for k, v in cable_dict.get("ions", {}).items():
            def_.add_ion(k, _parse_ion_def(cls, v))
        for mech_id, v in cable_dict.get("mechanisms", {}).items():
            def_.add_mech(mech_id, _parse_mech_def(cls, v))
        for label, v in cable_dict.get("synapses", {}).items():
            def_.add_synapse(label, _parse_synapse_def(cls, label, v))
        return def_
    except Exception as err:
        # Chain explicitly so the more specific underlying error (e.g. of a
        # nested ion or synapse definition) is reported as the cause.
        raise ModelDefinitionError(
            f"{cable_dict} is not a valid cable type definition."
        ) from err


def _parse_ion_def(cls: typing.Type[Definition], ion_dict: IonDict):
    """
    Build an ion from its dictionary form by passing the items as keyword
    arguments to ``cls.ion_class``.

    :raises ModelDefinitionError: If the ion cannot be created.
    """
    try:
        return cls.ion_class(**ion_dict)
    except Exception as err:
        raise ModelDefinitionError(
            f"{ion_dict} is not a valid ion definition."
        ) from err


def _parse_mech_def(cls: typing.Type[Definition], mech_dict: dict[str, float]):
    """
    Build a mechanism from a parameter dictionary; the dictionary is copied.

    :raises ModelDefinitionError: If the mechanism cannot be created.
    """
    try:
        return cls.mechanism_class(mech_dict.copy())
    except Exception as err:
        raise ModelDefinitionError(
            f"{mech_dict} is not a valid mechanism definition."
        ) from err


def _parse_synapse_def(cls: typing.Type[Definition], key, synapse_dict: SynapseDict):
    """
    Build a synapse from its short or expanded dictionary form.

    :param key: Label of the synapse; used as mechanism id when the
      dictionary does not specify a ``mechanism``.
    :raises ModelDefinitionError: If the synapse cannot be created.
    """
    try:
        if "mechanism" in synapse_dict:
            # If `mechanism` is specified, it must be an expanded dict
            synapse = cls.synapse_class(
                # And if no parameters are given, set no parameters
                synapse_dict.get("parameters", {}).copy(),
                synapse_dict["mechanism"],
            )
        else:
            # Otherwise, unless the key `parameters` is given, assume it's
            # short form
            synapse = cls.synapse_class(
                # And treat all given dict items as parameters
                synapse_dict.get("parameters", synapse_dict).copy(),
                key,
            )
        return synapse
    except Exception as err:
        raise ModelDefinitionError(
            f"{synapse_dict} is not a valid synapse definition."
        ) from err
class mechdict(dict):
    """
    Dictionary keyed by mechanism ids in which a string key ``"x"`` and the
    one-element tuple ``("x",)`` are the same key.

    String keys are always stored in tuple form. The normalisation is applied
    to every key-based operation, including construction, so that lookups,
    membership tests and deletions all agree with each other.
    """

    @staticmethod
    def _key(key):
        # Wrap strings in a tuple; any other key is used unchanged.
        return (key,) if isinstance(key, str) else key

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Route the initial items through `update` so their keys are
        # normalised as well.
        self.update(*args, **kwargs)

    def __getitem__(self, item):
        return super().__getitem__(self._key(item))

    def __setitem__(self, key, value):
        super().__setitem__(self._key(key), value)

    def __delitem__(self, key):
        super().__delitem__(self._key(key))

    def __contains__(self, key):
        return super().__contains__(self._key(key))

    def get(self, key, default=None):
        return super().get(self._key(key), default)

    def pop(self, key, *default):
        return super().pop(self._key(key), *default)

    def setdefault(self, key, default=None):
        return super().setdefault(self._key(key), default)

    def update(self, *args, **kwargs):
        # `dict(...)` accepts the same inputs as `dict.update`: a mapping or
        # an iterable of pairs, plus keyword arguments.
        for key, value in dict(*args, **kwargs).items():
            self[key] = value