"""Ordered collections of schema entries representing a single section of the HED vocabulary."""
from typing import Union
from hed.schema.hed_schema_entry import HedSchemaEntry, UnitClassEntry, UnitEntry, HedTagEntry
from hed.schema.hed_schema_constants import HedSectionKey, HedKey, HedKeyOld
# Maps each schema section key to the entry class used for entries of that
# section. HedSchemaSection.__init__ looks the class up here with .get(), so a
# section key that is missing from this table yields None rather than raising.
entries_by_section = {
HedSectionKey.Properties: HedSchemaEntry,
HedSectionKey.Attributes: HedSchemaEntry,
HedSectionKey.UnitModifiers: HedSchemaEntry,
HedSectionKey.Units: UnitEntry,
HedSectionKey.UnitClasses: UnitClassEntry,
HedSectionKey.ValueClasses: HedSchemaEntry,
HedSectionKey.Tags: HedTagEntry,
}
[docs]
class HedSchemaSection:
    """Typed container for all entries in one section of a loaded HED schema.

    A :class:`~hed.schema.HedSchema` is divided into sections (tags, unit
    classes, units, value classes, attributes, properties, unit modifiers).
    Each section is a ``HedSchemaSection`` that maps entry names to their
    :class:`~hed.schema.hed_schema_entry.HedSchemaEntry` objects and tracks
    which attributes are valid for that section. Names are case-folded before
    being used as keys only when the section is case-insensitive.

    The concrete entry type for each section is determined by
    :data:`entries_by_section`:

    +-------------------------------+----------------+
    | Section key                   | Entry type     |
    +===============================+================+
    | ``HedSectionKey.Tags``        | HedTagEntry    |
    +-------------------------------+----------------+
    | ``HedSectionKey.UnitClasses`` | UnitClassEntry |
    +-------------------------------+----------------+
    | ``HedSectionKey.Units``       | UnitEntry      |
    +-------------------------------+----------------+
    | everything else               | HedSchemaEntry |
    +-------------------------------+----------------+

    **Use this class directly when you need to:**

    - Iterate over all entries in a specific schema section.
    - Build schema comparison or diff tools.
    - Access ``valid_attributes`` to determine which attributes are legal for
      a given section.

    Attributes:
        all_names (dict[str, HedSchemaEntry]): Map from name key to entry. The key
            is the case-folded name when ``case_sensitive`` is False.
        all_entries (list[HedSchemaEntry]): Every added entry, in insertion order
            (duplicates included).
        valid_attributes (dict[str, HedSchemaEntry]): Attribute entries that are
            declared valid for this section.
        case_sensitive (bool): Whether name lookups are case-sensitive.
    """

    # Instances define __eq__ and are mutable, so they are deliberately unhashable.
    # (Python already sets this implicitly when __eq__ is defined; it is spelled out here.)
    __hash__ = None

    def __init__(self, section_key, case_sensitive=True):
        """Construct schema section.

        Parameters:
            section_key (HedSectionKey): Name of the schema section.
            case_sensitive (bool): If True, names are case-sensitive.
        """
        # {name_key: HedSchemaEntry}; name_key is case-folded when not case_sensitive.
        self.all_names = {}
        self._section_key = section_key
        self.case_sensitive = case_sensitive

        # Points to the entries in attributes
        self.valid_attributes = {}
        # {attribute_name: [entries having that attribute]}, filled lazily by
        # get_entries_with_attribute and emptied whenever an entry is added.
        self._attribute_cache = {}

        # Entry class used by _create_tag_entry; None if section_key is not in the table.
        self._section_entry = entries_by_section.get(section_key)
        # {name_key: [first entry, later entries sharing the same key, ...]}
        self._duplicate_names = {}

        self.all_entries = []

    @property
    def section_key(self):
        """Returns the HedSectionKey identifying this section.

        Returns:
            HedSectionKey: The key for this schema section.
        """
        return self._section_key

    @property
    def duplicate_names(self):
        """Returns a dict of entries that share the same name within this section.

        Returns:
            dict: Mapping of name key to the list of conflicting entries. The first
                element is the entry that is registered in ``all_names``.
        """
        return self._duplicate_names

    def _create_tag_entry(self, name):
        """Instantiate (but do not register) an entry of this section's entry type."""
        return self._section_entry(name, self)

    def _check_if_duplicate(self, name_key, new_entry):
        """Register new_entry under name_key, or record it as a duplicate.

        The first entry seen for a key stays in ``all_names``; later ones are only
        appended to ``duplicate_names``.

        Returns:
            HedSchemaEntry: Always ``new_entry``.
        """
        existing = self.all_names.get(name_key) if name_key in self.all_names else None
        if name_key in self.all_names:
            # Seed the duplicate list with the original entry the first time a clash is seen.
            self._duplicate_names.setdefault(name_key, [existing]).append(new_entry)
        else:
            self.all_names[name_key] = new_entry
        return new_entry

    def _add_to_dict(self, name, new_entry):
        """Add a name to the dictionary for this section.

        Parameters:
            name (str): The entry name as written in the schema.
            new_entry (HedSchemaEntry): The entry to register.

        Returns:
            HedSchemaEntry: Whatever ``_check_if_duplicate`` returns (subclasses may
                return an already-registered entry instead of ``new_entry``).
        """
        name_key = name if self.case_sensitive else name.casefold()
        return_entry = self._check_if_duplicate(name_key, new_entry)
        self.all_entries.append(new_entry)
        # Cached attribute queries were computed from the previous contents; drop them
        # so a query made before this add cannot hide the new entry.
        self._attribute_cache.clear()
        return return_entry

    def get_entries_with_attribute(
        self, attribute_name, return_name_only=False, schema_namespace=""
    ) -> list[Union[HedSchemaEntry, str]]:
        """Return entries or names with given attribute.

        Parameters:
            attribute_name (str): The name of the attribute (generally a HedKey entry).
            return_name_only (bool): If True, return the name as a string rather than the tag entry.
            schema_namespace (str): Prepends given namespace to each name if returning names.

        Returns:
            list[Union[HedSchemaEntry, str]]: List of HedSchemaEntry or strings representing the names.
                When entries are returned, the list is the internally cached object;
                callers should treat it as read-only.
        """
        if attribute_name not in self._attribute_cache:
            self._attribute_cache[attribute_name] = [
                tag_entry for tag_entry in self.values() if tag_entry.has_attribute(attribute_name)
            ]

        cache_val = self._attribute_cache[attribute_name]
        if return_name_only:
            return [f"{schema_namespace}{tag_entry.name}" for tag_entry in cache_val]
        return cache_val

    # ===============================================
    # Simple wrapper functions to make this class primarily function as a dict
    # ===============================================
    def __iter__(self):
        return iter(self.all_names)

    def __len__(self):
        return len(self.all_names)

    def items(self):
        """Return a view of (name key, entry) pairs."""
        return self.all_names.items()

    def values(self):
        """Return a view of the registered entries (one per name key)."""
        return self.all_names.values()

    def keys(self):
        """Return a view of the name keys."""
        return self.all_names.keys()

    def __getitem__(self, key):
        if not self.case_sensitive:
            key = key.casefold()
        return self.all_names[key]

    def get(self, key):
        """Return the entry registered under key, or None if there is none.

        Parameters:
            key (str): The name to look up.

        Returns:
            HedSchemaEntry or None: The matching entry, or None when the lookup fails.
        """
        try:
            return self.__getitem__(key)
        except KeyError:
            return None

    def __eq__(self, other):
        # Let Python fall back to its default handling (unequal) for foreign types
        # instead of failing on a missing attribute.
        if not isinstance(other, HedSchemaSection):
            return NotImplemented
        return (
            self.all_names == other.all_names
            and self._section_key == other._section_key
            and self.case_sensitive == other.case_sensitive
            and self.duplicate_names == other.duplicate_names
        )

    def __bool__(self):
        return bool(self.all_names)

    def _finalize_section(self, hed_schema):
        """Call ``finalize_entry(hed_schema)`` on every entry, in insertion order."""
        for entry in self.all_entries:
            entry.finalize_entry(hed_schema)
[docs]
class HedSchemaUnitSection(HedSchemaSection):
    """The schema section containing units.

    Units carrying the ``unitSymbol`` attribute are keyed exactly as given, while
    all other units are keyed by their case-folded name.
    """

    def _check_if_duplicate(self, name_key, new_entry):
        """Register a unit, case-folding the key unless the unit is a unit symbol.

        Parameters:
            name_key (str): The key computed by the caller for this unit.
            new_entry (UnitEntry): The unit being added.

        Returns:
            The result of the base-class duplicate check for the adjusted key.
        """
        is_symbol = new_entry.has_attribute(HedKey.UnitSymbol)
        effective_key = name_key if is_symbol else name_key.casefold()
        return super()._check_if_duplicate(effective_key, new_entry)

    def __getitem__(self, key):
        """Look up a unit, honoring case only where the unit is a symbol.

        An exact key match always wins. Otherwise the case-folded key is tried, and
        that fallback is accepted only for units that are not unit symbols.

        Returns:
            The matching entry, or None when nothing matches (no KeyError is raised).
        """
        exact_match = self.all_names.get(key)
        if exact_match is not None:
            return exact_match

        folded_match = self.all_names.get(key.casefold())
        if folded_match is None:
            return None
        # Unit symbols must match exactly, so a case-folded hit on one is a miss.
        if folded_match.has_attribute(HedKey.UnitSymbol):
            return None
        return folded_match
[docs]
class HedSchemaUnitClassSection(HedSchemaSection):
    """The schema section containing unit classes."""

    def _check_if_duplicate(self, name_key, new_entry):
        """Allow adding units to existing unit classes via a placeholder entry.

        A placeholder is an entry whose only attribute is ``HedKey.InLibrary``. When
        such an entry names a unit class that is already registered, the existing
        entry is returned and nothing is recorded as a duplicate. Every other case
        is handled by the base-class duplicate check.

        Parameters:
            name_key (str): The key computed by the caller for this unit class.
            new_entry (UnitClassEntry): The unit class entry being added.

        Returns:
            The already-registered entry for a placeholder, otherwise the result of
            the base-class duplicate check.
        """
        # Test membership on the dict itself: the section object has no __contains__,
        # so `name_key in self` would iterate over every key instead of hashing.
        if name_key in self.all_names:
            attributes = new_entry.attributes
            if len(attributes) == 1 and HedKey.InLibrary in attributes:
                return self.all_names[name_key]
        return super()._check_if_duplicate(name_key, new_entry)