# Source code for arborize.builders._arbor

import dataclasses
import typing
from collections import defaultdict
from itertools import tee
from math import isnan

if typing.TYPE_CHECKING:
    import arbor

    from .. import CableType
    from ..schematic import CableBranch, Point, Schematic


class CableCellTemplate:
    """Holder for the three Arbor objects a cable cell is constructed from.

    The morphology, label dictionary and decor are kept exactly as given so
    that :meth:`build` can create a cable cell from them whenever needed.
    """

    def __init__(
        self,
        morphology: "arbor.morphology",
        labels: "arbor.label_dict",
        decor: "arbor.decor",
    ):
        # Stored by reference; nothing is copied or validated here.
        self.morphology, self.labels, self.decor = morphology, labels, decor

    def build(self):
        """Return ``arbor.cable_cell(morphology, decor, labels)`` for this template."""
        # Arbor is only imported at the moment a cell is actually built.
        import arbor

        components = (self.morphology, self.decor, self.labels)
        return arbor.cable_cell(*components)


def hash_labelset(labels: list[str]):
    """Return a string key that identifies a combination of labels.

    The labels are sorted first, so the key does not depend on their order.
    Every ``&`` inside a label is doubled, and the results are joined with a
    single ``&``.
    """
    escaped = []
    for label in sorted(labels):
        escaped.append(label.replace("&", "&&"))
    return "&".join(escaped)


def get_label_dict(schematic: "Schematic"):
    """Number every distinct label combination and build the Arbor label dict.

    Each distinct combination of labels found on the branches of the
    schematic's points receives the next free integer tag. Every label is then
    defined as the region made of all tags whose combination contains it.

    :param schematic: Schematic whose ``cables`` are scanned.
    :returns: A tuple of the mapping from label combination key (see
      ``hash_labelset``) to tag, and the ``arbor.label_dict`` of regions.
    """
    import arbor

    tag_of_labelset: dict[str, int] = {}
    tags_of_label = defaultdict(list)
    for cable in schematic.cables:
        for point in cable.points:
            key = hash_labelset(point.branch.labels)
            if key in tag_of_labelset:
                # This combination already has a tag.
                continue
            new_tag = len(tag_of_labelset)
            for name in point.branch.labels:
                tags_of_label[name].append(new_tag)
            tag_of_labelset[key] = new_tag

    # A label covering several tags becomes a `join` of the individual tag regions.
    regions = {}
    for name, tags in tags_of_label.items():
        if len(tags) > 1:
            regions[name] = "(join " + " ".join(f"(tag {tag})" for tag in tags) + ")"
        else:
            regions[name] = f"(tag {tags[0]})"
    return tag_of_labelset, arbor.label_dict(regions)


def _to_units(value, unit: "arbor.units.unit") -> "arbor.units.quantity":
    """Express ``value`` in ``unit``.

    Anything that is not an ``arbor.units.quantity`` (or any value at all when
    Arbor has no ``units`` attribute) is simply multiplied by ``unit``. A
    quantity is first converted with ``value_as``.

    :raises ValueError: If converting a quantity to ``unit`` yields NaN.
    """
    import arbor

    # todo: drop when units are released
    is_quantity = hasattr(arbor, "units") and isinstance(value, arbor.units.quantity)
    if not is_quantity:
        return value * unit

    magnitude = value.value_as(unit)
    # A NaN magnitude is treated as a failed conversion.
    if isnan(magnitude):
        raise ValueError(f"Can't convert {value.units} to {unit}.")
    return magnitude * unit


def paint_cable_type_cable(decor: "arbor.decor", label: str, cable_type: "CableType"):
    """Paint the cable properties of ``cable_type`` on the region named ``label``.

    ``cable_type.cable.cm`` is painted as ``cm`` and ``cable_type.cable.Ra`` as
    ``rL``, both passed through ``_to_units`` first.
    """
    import arbor

    # When Arbor has no `units` attribute, a plain factor of 1 is used instead.
    if hasattr(arbor, "units"):
        cm_unit = arbor.units.F / arbor.units.m2
        rl_unit = arbor.units.Ohm * arbor.units.cm
    else:
        cm_unit = rl_unit = 1

    region = f'"{label}"'
    capacitance = _to_units(cable_type.cable.cm, cm_unit)
    resistivity = _to_units(cable_type.cable.Ra, rl_unit)
    decor.paint(region, cm=capacitance, rL=resistivity)


def paint_cable_type_ions(decor: "arbor.decor", label: str, cable_type: "CableType"):
    """Paint the ion properties of ``cable_type`` on the region named ``label``.

    For every entry of ``cable_type.ions`` the fields of the ion dataclass are
    converted with ``_to_units`` and painted as keyword arguments. The ion name
    is passed as ``ion=``; if that call raises ``TypeError`` it is retried
    with the older ``ion_name=`` keyword.

    :param decor: Decor to paint on.
    :param label: Name of the region; it is wrapped in double quotes.
    :param cable_type: Cable type whose ``ions`` mapping is painted.
    """
    import arbor

    if hasattr(arbor, "units"):
        units = {
            "rev_pot": arbor.units.mV,
            "int_con": arbor.units.mM,
            "ext_con": arbor.units.mM,
        }
    else:
        # todo: drop when units are released
        units = defaultdict(lambda: 1)

    region = f'"{label}"'
    for ion_name, ion in cable_type.ions.items():
        # Convert once and outside the `try`, so that a TypeError raised by the
        # conversion itself is not mistaken for an unsupported `ion` keyword.
        properties = {
            k: _to_units(v, units[k]) for k, v in dataclasses.asdict(ion).items()
        }
        try:
            decor.paint(region, ion=ion_name, **properties)
        except TypeError:
            # todo: drop when units are released
            # Support older `ion_name` kwarg
            decor.paint(region, ion_name=ion_name, **properties)


def paint_cable_type_mechanisms(
    decor: "arbor.decor", label: str, cable_type: "CableType"
):
    """Paint every mechanism of ``cable_type`` on the region named ``label``.

    Each entry of ``cable_type.mechs`` is wrapped in an ``arbor.density`` built
    from its key and the mechanism's ``parameters``.
    """
    import arbor

    region = f'"{label}"'
    for name, mechanism in cable_type.mechs.items():
        placed = arbor.density(name, mechanism.parameters)
        decor.paint(region, placed)


def paint_cable_type(decor: "arbor.decor", label: str, cable_type: "CableType"):
    """Paint cable properties, ions and mechanisms of ``cable_type``, in that order."""
    painters = (
        paint_cable_type_cable,
        paint_cable_type_ions,
        paint_cable_type_mechanisms,
    )
    for paint in painters:
        paint(decor, label, cable_type)


def get_decor(schematic: "Schematic"):
    """Create an ``arbor.decor`` painted with every cable type of the schematic.

    The cable types come from ``schematic.definition.get_cable_types()``; each
    key of that mapping is used as the label of the region to paint.
    """
    import arbor

    painted = arbor.decor()
    cable_types = schematic.definition.get_cable_types()
    for region_label, region_type in cable_types.items():
        paint_cable_type(painted, region_label, region_type)
    return painted


def arbor_build(schematic: "Schematic"):
    """Build an Arbor cable cell from a schematic.

    The schematic is frozen first. If it has no ``arbor`` attribute yet, a
    segment tree is assembled from its cables and stored, together with the
    label dictionary and the decor, as a ``CableCellTemplate`` on
    ``schematic.arbor``. The cell is then built from that template.

    :param schematic: Schematic to build the cell from.
    :returns: The result of ``schematic.arbor.build()``.
    :raises RuntimeError: If a branch has fewer than 2 points.
    """
    import arbor

    schematic.freeze()
    if not hasattr(schematic, "arbor"):
        tree = arbor.segment_tree()
        # Stores the ids of the segments to append to.
        branch_endpoints: dict["CableBranch", int] = {}
        labelsets, label_dict = get_label_dict(schematic)
        for bid, branch in enumerate(schematic.cables):
            if len(branch.points) < 2:
                # Empty branches mess up the branch id numbering, so we forbid them
                raise RuntimeError(f"Branch {bid} needs at least 2 points.")
            # Set up pairwise iterators: `ends` runs one point ahead of `starts`.
            starts, ends = tee(branch.points)
            next(ends)
            # Start branch from the endpoint, if the branch has a parent.
            ptid = branch_endpoints[branch.parent] if branch.parent else arbor.mnpos
            for p1, p2 in zip(starts, ends):
                # Tag it with a unique tag per label combination
                tag = hash_labelset(p2.branch.labels)
                ptid = tree.append(
                    ptid, _mkpt(p1), _mkpt(p2), tag=labelsets.get(tag)
                )
            # Child branches continue from the last segment of this branch.
            branch_endpoints[branch] = ptid

        schematic.arbor = CableCellTemplate(
            arbor.morphology(tree),
            label_dict,
            get_decor(schematic),
        )

    return schematic.arbor.build()
def _mkpt(p: "Point") -> "arbor.mpoint":
    """Convert a schematic point into an ``arbor.mpoint``.

    The point's ``coords`` are unpacked as the leading arguments, followed by
    its ``radius``.
    """
    import arbor

    return arbor.mpoint(*p.coords, p.radius)