# Source code for arborize.schematic

import typing
from collections import deque
from typing import Iterable, Optional, Union

import errr

from ._util import get_location_name
from .definitions import CableType, ModelDefinition
from .exceptions import ConstructionError, FrozenError, ModelDefinitionError

if typing.TYPE_CHECKING:
    from builders._arbor import CableCellTemplate

    from .parameter import Parameter

# A location on a schematic: ``(branch id, point id)``, where the branch id indexes
# ``Schematic.cables`` and the point id indexes that cable's points.
Location = tuple[int, int]
# A pair of locations; presumably the (start, end) of a span — TODO confirm, interval
# parameters are not implemented yet (see ``Schematic.set_param``).
Interval = tuple[Location, Location]


def throw_frozen():
    """
    Signal that a frozen schematic was about to be modified.

    :raises FrozenError: Always.
    """
    message = "Can't alter finished schematic."
    raise FrozenError(message)
def _random_name(): import random import string return "".join(random.choices(string.ascii_uppercase, k=10))
class Schematic:
    """
    A schematic is an intermediate object that associates parameter definitions to
    points in space. You can define locations (3d coords + radius) and tag them with
    labels, or set parameters directly on the locations. You can pass a schematic to a
    Builder, which will freeze the schematic (no changes can be made anymore) and
    create a simulator specific instance of the model.

    Schematics create a user-facing layer of "virtual branches", which is the network
    graph of the created locations. However, NEURON does not support the resolution
    that arbor does, so an underlying layer of "true branches" is created. In NEURON,
    a map is kept on the model between the locations on the virtual branches and the
    locations on the true branches, so that we can arbitrarily split up true branches
    into smaller pieces to achieve the resolution we need.
    """

    arbor: typing.Optional["CableCellTemplate"]

    def __init__(self, name=None):
        self._name = name
        self._frozen = False
        self._definition: ModelDefinition = ModelDefinition()
        self.cables: list["CableBranch"] = []
        self.roots: list["UnitBranch"] = []
        self._named = 0

    def __iter__(self) -> typing.Iterator["UnitBranch"]:
        """
        Iterate over the unit branches in depth-first (pre-order) order.

        Roots are visited in the order they were added, and the children of each
        branch in the order they were attached.
        """
        # The right end of the deque is the top of the stack: push in reverse so the
        # first root/child is popped first.
        stack: deque["UnitBranch"] = deque(reversed(self.roots))
        while stack:
            branch = stack.pop()
            yield branch
            stack.extend(reversed(branch.children))

    def __len__(self):
        """Return the number of unit branches in the schematic."""
        return sum(1 for _ in self)

    @property
    def name(self):
        """
        Base name for all the instances of this model. Suffixed unique names for each
        instance can be obtained by calling ``create_name``.
        """
        return self._name

    @name.setter
    def name(self, value):
        if self._frozen:
            raise FrozenError("Can't change name of finished schematic.")
        self._name = value

    @property
    def definition(self):
        """
        Definition of the model, contains the definition of the parameters for the
        cables, mechanisms, and synapses of this model.

        A copy is returned: mutating it does not affect the schematic.
        """
        return self._definition.copy()

    @definition.setter
    def definition(self, value):
        if self._frozen:
            raise FrozenError("Can't change definitions of finished schematic.")
        self._definition = value

    def create_name(self):
        """
        Generate the next unique name for an instance of this model.

        :raises FrozenError: If the schematic is not frozen yet.
        """
        if not self._frozen:
            raise FrozenError(
                "Schematic must be finished before naming instances of it."
            )
        self._named += 1
        return f"{self._name}_{self._named}"

    def create_location(
        self, location: tuple[int, int], coords, radii, labels, endpoint=None
    ):
        """
        Add a new location to the schematic.

        A location is a tuple of the branch id and point-on-branch id. Locations must
        be appended in ascending order: either the next point of the last cable, or the
        first point (point id 0) of a new cable.

        :param location: ``(branch id, point id)`` of the new location.
        :param coords: Coordinates of the point, stored on the created ``Point``.
        :param radii: Radius of the point, stored on the created ``Point``.
        :param labels: Labels of the point; consecutive points of a cable with equal
          labels share a unit branch.
        :param endpoint: Optional ``(branch id, point id)`` of an existing location that
          becomes the parent of this location's cable and unit branch. Without an
          endpoint, the first point of a cable starts a new root unit branch.
        :return: The created ``Point``.
        :raises FrozenError: If the schematic is frozen.
        :raises ConstructionError: If the location is out of order. The schematic is
          left unchanged in that case.
        """
        if self._frozen:
            throw_frozen()
        bid, pid = location
        n_cables = len(self.cables)
        if bid == n_cables:
            # Starting a new cable. It is only created once the whole location has been
            # validated, so that a rejected location leaves the schematic untouched.
            branch = None
            expected_pid = 0
        elif n_cables and bid == n_cables - 1:
            # Continuing the last cable.
            branch = self.cables[bid]
            expected_pid = len(branch.points)
        else:
            # Ascending branch order violated
            raise ConstructionError(
                "Locations need to be constructed in order. Can't construct "
                f"{location}, should construct {self._next_locations()}."
            )
        if pid != expected_pid:
            # Ascending point order violated
            raise ConstructionError(
                f"Locations need to be constructed in order. Can't construct {location}"
                f", should construct {self._next_locations()}."
            )
        if branch is None:
            branch = CableBranch()
            self.cables.append(branch)
        # Appending the point to the cable may create a new unit branch.
        point = branch.append(location, coords, radii, labels)
        if endpoint:
            # If an endpoint was passed, set it as our parent at both the cable and the
            # unit level.
            cable_parent = self.cables[endpoint[0]]
            unit_parent = cable_parent.points[endpoint[1]].branch
            branch.parent = cable_parent
            cable_parent.children.append(branch)
            point.branch.parent = unit_parent
            unit_parent.children.append(point.branch)
        elif pid == 0:
            # The first point of a cable without an endpoint starts a new root.
            self.roots.append(point.branch)
        return point

    def _next_locations(self) -> str:
        """Describe the locations that may validly be created next (for errors)."""
        n_cables = len(self.cables)
        if not n_cables:
            return "(0.0)"
        last = n_cables - 1
        return f"({last}.{len(self.cables[last].points)}) or ({n_cables}.0)"

    def create_empty(self):
        """
        Create an empty branch.

        :raises FrozenError: If the schematic is frozen.
        """
        if self._frozen:
            throw_frozen()
        self.cables.append(CableBranch())

    def set_param(
        self, location: Union["Location", "Interval", str], param: "Parameter"
    ):
        """
        Set a parameter on a label, location or interval.

        :param location: A label name; locations and intervals are not supported yet.
        :param param: The parameter to set.
        :raises FrozenError: If the schematic is frozen.
        :raises NotImplementedError: If ``location`` is not a label name.
        """
        if self._frozen:
            throw_frozen()
        if isinstance(location, str):
            # Set the parameter on the schematic's own definition of the label. The
            # public ``definition`` property returns a copy, so it cannot be used here.
            self._definition[location].set(param)
        else:
            # Set parameter on the specific location or interval
            raise NotImplementedError(
                "Location or interval parameters not implemented yet."
            )

    def freeze(self):
        """
        Freeze the schematic. Most mutating operations will no longer be permitted.

        On the first call, every unit branch receives its concrete definition and an
        unnamed schematic receives a random name. Subsequent calls do nothing.

        :raises ValueError: If an unlabeled branch is missing a value.
        :raises ModelDefinitionError: If a labelled branch is missing a value.
        """
        if self._frozen:
            return
        self._flatten_branches(self.roots)
        if self._name is None:
            self._name = _random_name()
        self._frozen = True
        # If we are a constraint schematic, reconvert after freezing.
        definition = self.definition
        if hasattr(definition, "convert_to_constraints"):
            # NOTE(review): ``self.definition`` returns a copy, so this conversion does
            # not persist on ``self._definition``. Confirm whether that is intended.
            definition.convert_to_constraints()
            # fixme: ion defaults are not constraints
            from .constraints import Constraint

            for branch in self:
                for ion in branch.definition.ions.values():
                    for prop, value in ion:
                        setattr(ion, prop, Constraint.from_value(value))

    def _flatten_branches(self, branches: Iterable["UnitBranch"]):
        """
        Concretize the definition of the given unit branches and all their descendants.

        Branches are visited depth-first (pre-order). An explicit stack is used instead
        of recursion so that deep branch trees cannot hit the recursion limit.

        :raises ValueError: If an unlabeled branch is missing a value.
        :raises ModelDefinitionError: If a labelled branch is missing a value.
        """
        stack = list(branches)
        stack.reverse()
        while stack:
            branch = stack.pop()
            # Merge all labels and params into the branch's own definition.
            branch.definition = self._makedef(branch.labels)
            try:
                # Assert that none of the values are missing (= `None`)
                branch.definition.assert_()
            except ValueError as e:
                locstr = get_location_name(branch.points)
                if not branch.labels:
                    raise ValueError(
                        f"Unlabeled {locstr} is missing value for {e.args[1]}."
                    ) from None
                raise ModelDefinitionError(
                    f"{locstr} labelled {errr.quotejoin(branch.labels)} "
                    f"misses value for {e.args[1:]}"
                ) from None
            stack.extend(reversed(branch.children))

    def _makedef(self, labels: typing.Sequence[str]) -> "CableType":
        """Anchor the cable types of ``labels`` into one cable type definition."""
        # The cable type priority order is the key order of the cable type dict.
        sort_labels = self._make_label_sorter()
        # Read the stored definition directly: the public property would copy the
        # whole model definition on every access.
        definition = self._definition
        return definition.cable_type_class.anchor(
            (definition._cable_types.get(label) for label in sort_labels(labels)),
            synapses=definition.get_synapse_types(),
            use_defaults=definition.use_defaults,
            ion_class=definition.ion_class,
        )

    def get_cable_types(self):
        """Return the cable types of the model definition."""
        return self._definition.get_cable_types()

    def get_synapse_types(self):
        """Return the synapse types of the model definition."""
        return self._definition.get_synapse_types()

    def get_compound_cable_types(self):
        """
        Map the joined label name of every unit branch to that branch's definition.

        :raises RuntimeError: If the schematic is not frozen.
        """
        if not self._frozen:
            raise RuntimeError("Can only compound cable types in frozen schematic.")
        name_labels = self._make_label_namer()
        return {name_labels(branch.labels): branch.definition for branch in self}

    def _make_label_sorter(self):
        """
        Return a function that sorts labels by cable type definition order.

        Labels without a cable type sort first; ties are broken by label name.
        """
        order = {
            label: index
            for index, label in enumerate(self._definition._cable_types)
        }

        def label_order(lbl):
            return (order.get(lbl, -1), lbl)

        return lambda labels: sorted(labels, key=label_order)

    def _make_label_namer(self):
        """
        Return a function that joins sorted labels with ``_``.

        Underscores inside each label are doubled before joining.
        """
        sort_labels = self._make_label_sorter()
        return lambda labels: "_".join(
            lbl.replace("_", "__") for lbl in sort_labels(labels)
        )
class Point:
    """
    A single location on a cable branch.

    :param loc: ``(branch id, point id)`` of the point.
    :param branch: The unit branch the point belongs to.
    :param coords: Coordinates of the point.
    :param radius: Radius at the point.
    """

    def __init__(self, loc, branch: "UnitBranch", coords, radius):
        # Assigned left to right, in the same attribute order as before.
        self.loc, self.coords, self.radius, self.branch = loc, coords, radius, branch
class Branch:
    """
    Base for cable and unit branches: an ordered list of points, linked into a tree
    through ``parent`` and ``children``.
    """

    points: list[Point]
    parent: Optional["Branch"]
    children: list["Branch"]

    def __init__(self):
        # Each instance gets its own fresh containers; a new branch has no parent.
        self.points, self.parent, self.children = [], None, []
class CableBranch(Branch):
    """
    A branch made of consecutively created locations.

    Its points are grouped into unit branches: a new unit branch is started, as a child
    of the previous one, whenever a point's labels differ from the previous point's.
    """

    parent: Optional["CableBranch"]
    children: list["CableBranch"]

    def append(self, loc, coords, radius, labels):
        """
        Append a point to the cable, creating a new unit branch if needed.

        :param loc: ``(branch id, point id)`` of the point.
        :param coords: Coordinates of the point.
        :param radius: Radius at the point.
        :param labels: Labels of the point; a copy is stored on its unit branch.
        :return: The created ``Point``.
        """
        if not self.points:
            # First point of the cable: always starts a fresh unit branch.
            unit = UnitBranch()
        else:
            prev_unit = self.points[-1].branch
            if prev_unit.labels == labels:
                # Same labels: keep growing the current unit branch.
                unit = prev_unit
            else:
                # Labels changed: start a child unit branch.
                unit = UnitBranch()
                unit.parent = prev_unit
                prev_unit.children.append(unit)
        unit.labels = labels.copy()
        point = Point(loc, unit, coords, radius)
        unit.points.append(point)
        self.points.append(point)
        return point
class UnitBranch(Branch):
    """
    A run of consecutive points of a cable that share the same labels.

    ``definition`` is assigned when the owning schematic is frozen.
    """

    parent: Optional["UnitBranch"]
    children: list["UnitBranch"]
    labels: list[str]
    definition: CableType

    def append(self, point):
        """Add ``point`` at the end of this unit branch."""
        self.points.extend([point])