Source code for arborize.constraints

import dataclasses
import itertools
import typing

from ._util import MechId
from .definitions import (
    CableProperties,
    CableType,
    Definition,
    Ion,
    Mechanism,
    Synapse,
    _parse_dict_def,
)


class Constraint:
    """
    Interval of acceptable values for a single parameter.

    The interval is stored as a raw lower and upper bound. When a relative
    tolerance is set, the bounds read through :attr:`lower` and :attr:`upper`
    are scaled by ``1 - tolerance`` and ``1 + tolerance`` respectively.

    NOTE(review): the scaling is multiplicative, so for negative bounds it
    moves the lower bound toward zero and the upper bound away from zero;
    confirm whether negative quantities need absolute-value based margins.
    """

    def __init__(self):
        self._upper = None
        self._lower = None
        self._tolerance = None

    @property
    def tolerance(self):
        """Relative tolerance, or ``None`` when no tolerance is applied."""
        return self._tolerance

    @property
    def lower(self):
        """
        Lower bound, scaled by ``1 - tolerance`` when a tolerance is set.

        Returns ``None`` when the bound has not been set.
        """
        value = self._lower
        if value is None or self.tolerance is None:
            return value
        # Plain multiplication (not ``*=``) so a mutable stored bound is never
        # modified in place by merely reading the property.
        return value * (1 - self.tolerance)

    @lower.setter
    def lower(self, value: float):
        self._lower = value

    @property
    def upper(self):
        """
        Upper bound, scaled by ``1 + tolerance`` when a tolerance is set.

        Returns ``None`` when the bound has not been set.
        """
        value = self._upper
        if value is None or self.tolerance is None:
            return value
        # The upper bound grows with the tolerance; scaling it by
        # ``1 - tolerance`` like the lower bound would shift the interval
        # instead of widening it.
        return value * (1 + self.tolerance)

    @upper.setter
    def upper(self, value: float):
        self._upper = value

    @classmethod
    def from_value(cls, value: "ConstraintValue") -> "Constraint":
        """
        Coerce ``value`` into a constraint.

        :param value: An existing :class:`Constraint` (returned unchanged), a
          list or tuple whose first two items are the lower and upper bound,
          or a single value used for both bounds.
        :returns: The resulting constraint.
        """
        if isinstance(value, Constraint):
            return value
        if isinstance(value, (list, tuple)):
            lower, upper = value[0], value[1]
        else:
            lower = upper = value
        constraint = cls()
        constraint.lower = lower
        constraint.upper = upper
        return constraint

    def set_tolerance(self, tolerance=None):
        """
        Set the relative tolerance; ``None`` disables the scaling.

        :returns: This constraint, so that calls can be chained.
        """
        self._tolerance = tolerance
        return self


# Every form accepted by ``Constraint.from_value``: a ready-made constraint, a
# single value used for both bounds, or a (lower, upper) pair.
ConstraintValue = typing.Union[
    Constraint,
    float,
    tuple[float, float],
    list[float],
]


@dataclasses.dataclass
class CablePropertyConstraints(CableProperties):
    """
    Cable properties whose values are held as :class:`Constraint` objects.
    """

    # NOTE(review): the unit is given as "ohm/cm"; axial resistivity is usually
    # quoted in ohm*cm - confirm the intended unit.
    Ra: Constraint
    """
    Constraint on the axial resistivity in ohm/cm
    """
    # NOTE(review): ``cm`` conventionally names the membrane capacitance rather
    # than a conductance - confirm the description below.
    cm: Constraint
    """
    Constraint on the membrane conductance
    """

    def __post_init__(self):
        """Coerce the value of every dataclass field into a constraint."""
        for spec in dataclasses.fields(self):
            _convert_field(self, spec)


class CablePropertyConstraintsDict(typing.TypedDict, total=False):
    """Dictionary form of cable property constraints; every key is optional."""

    Ra: ConstraintValue
    cm: ConstraintValue


@dataclasses.dataclass
class IonConstraints(Ion):
    """
    Ion description whose fields are held as :class:`Constraint` objects.
    """

    rev_pot: Constraint
    int_con: Constraint
    ext_con: Constraint

    def __post_init__(self):
        """Coerce the value of every dataclass field into a constraint."""
        for spec in dataclasses.fields(self):
            _convert_field(self, spec)


class IonConstraintsDict(typing.TypedDict, total=False):
    """Dictionary form of ion constraints; every key is optional."""

    rev_pot: ConstraintValue
    int_con: ConstraintValue
    ext_con: ConstraintValue


class MechanismConstraints(Mechanism):
    """
    Mechanism whose parameter values are held as :class:`Constraint` objects.
    """

    parameters: dict[str, Constraint]

    def __init__(self, parameters: dict[str, ConstraintValue]):
        """
        Coerce each parameter value into a constraint and hand the resulting
        mapping to the parent initializer.

        :param parameters: Mapping of parameter names to constraint values.
        """
        constrained = {}
        for name, raw in parameters.items():
            constrained[name] = Constraint.from_value(raw)
        super().__init__(constrained)


class SynapseConstraints(Synapse, MechanismConstraints):
    """
    Combination of ``Synapse`` and :class:`MechanismConstraints`; it defines
    no members of its own.
    """


# Accepted dictionary forms for a synapse constraint: either a flat mapping of
# names to constraint values, or a mapping with the optional keys
# ``mechanism`` and ``parameters``.
SynapseConstraintsDict = typing.Union[
    dict[str, ConstraintValue],
    typing.TypedDict(
        "SynapseConstraintsDict",
        {"mechanism": MechId, "parameters": dict[str, ConstraintValue]},
        total=False,
    ),
]


class CableTypeConstraints(CableType):
    """
    Cable type whose cable properties, mechanisms, ions and synapses are
    annotated with their constraint-based counterparts.
    """

    cable: CablePropertyConstraints
    mechs: dict[MechId, MechanismConstraints]
    ions: dict[str, IonConstraints]
    synapses: dict[str, SynapseConstraints]

    @classmethod
    def default(cls, ion_class=IonConstraints):
        """
        Create the parent class's default cable type and coerce every field of
        its ``cable`` attribute into a :class:`Constraint`.

        :param ion_class: Forwarded to the parent ``default`` as ``ion_class``.
        :returns: The object produced by the parent ``default``.
        """
        cable_type = super().default(ion_class=ion_class)
        for spec in dataclasses.fields(cable_type.cable):
            _convert_field(cable_type.cable, spec)
        return cable_type


class CableTypeConstraintsDict(typing.TypedDict, total=False):
    """Dictionary form of cable type constraints; every key is optional."""

    cable: CablePropertyConstraintsDict
    mechanisms: dict[MechId, dict[str, ConstraintValue]]
    ions: dict[str, IonConstraintsDict]
    synapses: dict[str, SynapseConstraintsDict]


class ConstraintsDefinition(
    Definition[
        CableTypeConstraints,
        CablePropertyConstraints,
        IonConstraints,
        MechanismConstraints,
        SynapseConstraints,
    ]
):
    """
    Definition whose component classes are the constraint-based variants.
    """

    # Plain class attributes instead of ``@classmethod`` stacked on
    # ``@property``: that stacking was deprecated in Python 3.11 and removed in
    # 3.13. Attribute access on the class or on an instance still yields the
    # same classes.
    cable_type_class = CableTypeConstraints
    cable_properties_class = CablePropertyConstraints
    ion_class = IonConstraints
    mechanism_class = MechanismConstraints
    synapse_class = SynapseConstraints

    def set_tolerance(self, tolerance=None):
        """
        Apply ``tolerance`` to every constraint held by this definition.

        Covers the parameters of the entries in ``_synapse_types`` and, for
        each entry in ``_cable_types``, the cable property fields, the ion
        fields, and the parameters of its mechanisms and synapses.

        :param tolerance: Value passed to ``set_tolerance`` of each
          constraint; ``None`` removes the tolerance.
        """
        for syn in self._synapse_types.values():
            for constraint in syn.parameters.values():
                constraint.set_tolerance(tolerance)

        for cable_type in self._cable_types.values():
            # The cable properties and the ions are dataclasses whose fields
            # each hold one constraint.
            for holder in itertools.chain(
                (cable_type.cable,), cable_type.ions.values()
            ):
                for field in dataclasses.fields(holder):
                    getattr(holder, field.name).set_tolerance(tolerance)
            for mech in itertools.chain(
                cable_type.mechs.values(), cable_type.synapses.values()
            ):
                for constraint in mech.parameters.values():
                    constraint.set_tolerance(tolerance)


def _convert_field(obj, field):
    """
    Replace the attribute of ``obj`` named by ``field.name`` with the result
    of passing its current value through ``Constraint.from_value``.
    """
    name = field.name
    setattr(obj, name, Constraint.from_value(getattr(obj, name)))


class ConstraintsDefinitionDict(typing.TypedDict, total=False):
    """Dictionary form of a constraints definition; every key is optional."""

    cable_types: dict[str, CableTypeConstraintsDict]
    synapse_types: dict[MechId, SynapseConstraintsDict]


def define_constraints(
    constraints: ConstraintsDefinitionDict, tolerance=None, use_defaults=False
) -> ConstraintsDefinition:
    """
    Build a :class:`ConstraintsDefinition` from its dictionary form.

    :param constraints: Dictionary description, parsed with
      ``_parse_dict_def``.
    :param tolerance: Passed to ``set_tolerance`` of the parsed definition.
    :param use_defaults: Stored on the parsed definition as ``use_defaults``.
    :returns: The parsed definition.
    """
    definition = _parse_dict_def(ConstraintsDefinition, constraints)
    definition.set_tolerance(tolerance)
    definition.use_defaults = use_defaults
    return definition