import logging
import numpy
# Module-level logger shared by all the readers defined in this file.
logger = logging.getLogger(__name__)

# Translation table used by VUNWBReader: the key is the stimulus description
# stored in the NWB file, the value is the protocol name it is matched
# against when filtering on the requested target protocols.
PROTOCOL_VU_TO_BBP = {
    "X1PS_SubThresh_DA_0": "IV",
    "X2LP_Search_DA_0": "IDthresh",
    "X4PS_SupraThresh_DA_0": "IDrest",
    "CCSteps_DA_0": "Step",
}
def _normalize_scala_protocol_name(protocol_name):
    """Normalize a protocol label following the Scala reader rules.

    Args:
        protocol_name (str or bytes): label as read from the NWB file; bytes
            are decoded as UTF-8 before any comparison.

    Returns:
        str: "Step" when the lower-cased label is "na", or when it contains
        "step" without being exactly "genericstep". Otherwise the (decoded)
        label is returned with its original casing.
    """
    name = (
        protocol_name.decode("UTF-8")
        if isinstance(protocol_name, bytes)
        else protocol_name
    )
    lowered = name.lower()

    if lowered == "na":
        return "Step"
    if lowered != "genericstep" and "step" in lowered:
        return "Step"
    return name
class NWBReader:
    """Base class of the NWB readers.

    Stores the options shared by the readers and converts one
    voltage/current recording pair into a trace dictionary. Subclasses
    implement ``read``.
    """

    def __init__(self, content, target_protocols, repetition=None, v_file=None):
        """ Init

        Args:
            content (h5.File): NWB file
            target_protocols (list of str): list of the protocols to be read and returned
            repetition (list of int): id of the repetition(s) to be read and returned
            v_file (str): name of original file that can be retrieved in sweep's description
        """
        self.content = content
        self.target_protocols = target_protocols
        self.repetition = repetition
        self.v_file = v_file

    def read(self):
        """ Read the content of the NWB file

        Returns:
            data (list of dict): list of traces

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError()

    @staticmethod
    def _decode_unit(unit):
        """Return a unit attribute as ``str``, decoding it as UTF-8 if needed."""
        if isinstance(unit, str):
            return unit
        return unit.decode("UTF-8")

    def _format_nwb_trace(self, voltage, current, start_time, trace_name=None, repetition=None):
        """ Format the data from the NWB file to the format used by BluePyEfe

        Args:
            voltage (Dataset): voltage series; its ``conversion`` and ``unit``
                attributes are read
            current (Dataset): current series; its ``conversion`` and ``unit``
                attributes are read
            start_time (Dataset): starting time; its ``rate`` and ``unit``
                attributes are read
            trace_name (Dataset): name of the trace
            repetition (int): repetition number

        Returns:
            dict: formatted trace with the keys "voltage", "current" (float32
            arrays scaled by their conversion factor), "dt" (inverse of the
            rate), "id", "repetition", "i_unit", "v_unit" and "t_unit".
        """
        v_array = numpy.array(
            voltage[()] * voltage.attrs["conversion"], dtype="float32"
        )
        i_array = numpy.array(
            current[()] * current.attrs["conversion"], dtype="float32"
        )

        dt = 1. / float(start_time.attrs["rate"])

        # Each unit is decoded on its own: a file may store some units as
        # bytes and others as str, so the type of one says nothing about
        # the others.
        v_unit = self._decode_unit(voltage.attrs["unit"])
        i_unit = self._decode_unit(current.attrs["unit"])
        t_unit = self._decode_unit(start_time.attrs["unit"])

        return {
            "voltage": v_array,
            "current": i_array,
            "dt": dt,
            "id": str(trace_name),
            "repetition": repetition,
            "i_unit": i_unit,
            "v_unit": v_unit,
            "t_unit": t_unit,
        }
class AIBSNWBReader(NWBReader):
    """Reader for NWB files whose sweeps live under ``acquisition/timeseries``
    and carry their protocol name in an ``aibs_stimulus_name`` dataset."""

    def read(self):
        """ Read the content of the NWB file

        A sweep is kept when no target protocol was requested, or when its
        protocol name matches one of the target protocols (case-insensitive).

        Returns:
            data (list of dict): list of traces
        """
        traces = []
        recordings = self.content["acquisition"]["timeseries"]

        for sweep_name in tuple(recordings.keys()):
            recording = recordings[sweep_name]

            stimulus_name = recording["aibs_stimulus_name"][()]
            if not isinstance(stimulus_name, str):
                stimulus_name = stimulus_name.decode('UTF-8')

            if self.target_protocols:
                wanted = {prot.lower() for prot in self.target_protocols}
                if stimulus_name.lower() not in wanted:
                    continue

            # The stimulus is stored under the same sweep name as the recording.
            traces.append(
                self._format_nwb_trace(
                    voltage=recording["data"],
                    current=self.content["stimulus"]["presentation"][sweep_name]["data"],
                    start_time=recording["starting_time"],
                    trace_name=sweep_name,
                )
            )

        return traces
class ScalaNWBReader(NWBReader):
    """Reader for NWB files whose sweeps are direct children of ``acquisition``
    and whose protocol is given by the ``stimulus_description`` attribute."""

    def read(self):
        """ Read and format the content of the NWB file

        Sweeps are filtered on the normalized protocol name (case-insensitive)
        and, when repetitions were requested, on the repetition id looked up
        in ``general/intracellular_ephys/intracellular_recordings/repetition``.
        Sweeps without a matching stimulus series are skipped.

        Returns:
            data (list of dict): list of traces
        """
        traces = []

        if self.repetition:
            repetitions_content = (
                self.content['general']['intracellular_ephys']
                ['intracellular_recordings']['repetition']
            )
            # A single repetition given as int or str is wrapped in a list.
            if isinstance(self.repetition, (int, str)):
                self.repetition = [int(self.repetition)]

        acquisition = self.content['acquisition']
        for sweep in list(acquisition.keys()):
            # Name of the stimulus series paired with this acquisition series.
            key_current = sweep.replace('Series', 'StimulusSeries')

            try:
                raw_name = acquisition[sweep].attrs["stimulus_description"]
            except KeyError:
                logger.warning(f'Could not find "stimulus_description" attribute for {sweep}, Setting it as "Step"')
                raw_name = "Step"
            protocol_name = _normalize_scala_protocol_name(raw_name)

            if self.target_protocols:
                wanted = {prot.lower() for prot in self.target_protocols}
                if protocol_name.lower() not in wanted:
                    continue

            if key_current not in self.content['stimulus']['presentation']:
                continue

            repetition_id = None
            if self.repetition:
                # The trailing number of the sweep name indexes the repetition table.
                sweep_id = int(sweep.split("_")[-1])
                repetition_id = int(repetitions_content[sweep_id])
                if repetition_id not in self.repetition:
                    continue

            traces.append(
                self._format_nwb_trace(
                    voltage=self.content['acquisition'][sweep]['data'],
                    current=self.content['stimulus']['presentation'][key_current]['data'],
                    start_time=self.content['acquisition'][sweep]["starting_time"],
                    trace_name=sweep,
                    repetition=repetition_id,
                )
            )

        return traces
class BBPNWBReader(NWBReader):
    """Reader for NWB files indexed by a ``data_organization`` group laid out
    as cell / eCode / repetition / sweep / trace name."""

    def _get_repetition_keys_nwb(self, ecode_content, request_repetitions=None):
        """ Filter the names of the traces based on the requested repetitions

        Args:
            ecode_content (dict): content of the NWB file for one eCode/protocol
            request_repetitions (list of int): identifier of the requested repetitions

        Returns:
            list of str: list of the keys of the traces to be read

        Raises:
            ValueError: if a key is not of the form "repetition <int>", or if
                a requested repetition is not present in ``ecode_content``.
        """
        # A single repetition given as int or str is wrapped in a list.
        if isinstance(request_repetitions, (int, str)):
            request_repetitions = [int(request_repetitions)]

        available = list(ecode_content.keys())
        # Numeric identifier of each key, in the same order as `available`.
        identifiers = [int(key.replace("repetition ", "")) for key in available]

        if not request_repetitions:
            return available
        return [available[identifiers.index(wanted)] for wanted in request_repetitions]

    def _matches_v_file(self, trace_name):
        """Tell whether the description of ``trace_name`` ends with the same
        file name as ``self.v_file`` (only the part after the last "/" is
        compared). Logs and returns False otherwise."""
        trace_attrs = self.content["acquisition"][trace_name].attrs
        if "description" not in trace_attrs:
            logger.warning(
                "Ignoring %s because no description could be found.",
                trace_name
            )
            return False

        expected = self.v_file.split("/")[-1]
        found = trace_attrs.get("description", "").split("/")[-1]
        if expected != found:
            logger.debug(f"Ignoring {trace_name} not matching v_file")
            return False
        return True

    def _read_trace(self, trace_name, rep):
        """Format the trace ``trace_name`` of repetition key ``rep``, or return
        None when the trace has to be ignored."""
        # Derive the name of the stimulus series from the acquisition name.
        if "ccs_" in trace_name:
            key_current = trace_name.replace("ccs_", "ccss_")
        elif "ic_" in trace_name:
            key_current = trace_name.replace("ic_", "ics_")
        else:
            return None

        if key_current not in self.content["stimulus"]["presentation"]:
            logger.debug(f"Ignoring {key_current} not present in the stimulus presentation")
            return None

        if trace_name not in self.content["acquisition"]:
            logger.debug(f"Ignoring {trace_name} not present in the acquisition")
            return None

        # if we have v_file, check that trace comes from this original file
        if self.v_file is not None and not self._matches_v_file(trace_name):
            return None

        # The starting time is taken from the stimulus series, not from the
        # acquisition series.
        return self._format_nwb_trace(
            voltage=self.content["acquisition"][trace_name]["data"],
            current=self.content["stimulus"]["presentation"][key_current]["data"],
            start_time=self.content["stimulus"]["presentation"][key_current]["starting_time"],
            trace_name=trace_name,
            repetition=int(rep.replace("repetition ", "")),
        )

    def read(self):
        """ Read and format the content of the NWB file

        For every target protocol and every cell, the eCode is looked up by
        exact name first and case-insensitively otherwise; the traces of the
        requested repetitions are then formatted.

        Returns:
            data (list of dict): list of traces
        """
        data = []

        for ecode in self.target_protocols:
            for cell_id in self.content["data_organization"].keys():
                cell_content = self.content["data_organization"][cell_id]

                if ecode not in cell_content:
                    # Fall back on the first eCode matching case-insensitively.
                    new_ecode = next(
                        (ec for ec in cell_content if ec.lower() == ecode.lower()),
                        None,
                    )
                    if not new_ecode:
                        logger.debug(f"No eCode {ecode} in nwb.")
                        continue
                    logger.debug(
                        f"Could not find {ecode} in nwb file, will use {new_ecode} instead"
                    )
                    # The substituted name stays in use for the following cells.
                    ecode = new_ecode

                ecode_content = cell_content[ecode]
                repetition_keys = self._get_repetition_keys_nwb(
                    ecode_content, request_repetitions=self.repetition
                )

                for rep in repetition_keys:
                    for sweep in ecode_content[rep].keys():
                        for trace_name in list(ecode_content[rep][sweep].keys()):
                            trace = self._read_trace(trace_name, rep)
                            if trace is not None:
                                data.append(trace)

        return data
# Reader for the Howard et al. (2022) whole-cell electrophysiology NWB files.
class TRTNWBReader(NWBReader):
    """Read NWB files used in 'An in vitro whole-cell electrophysiology dataset of
    human cortical neurons' by Howard, Derek et al., 2022, doi.org/10.1093/gigascience/giac108.

    The files that can be read by this reader can be found at
    10.48324/dandi.000293/0.220708.1652 (human), and
    10.48324/dandi.000292/0.220708.1652 (mouse).
    """

    # For names with more than one "_" (e.g. Index_0_0_0): last index of the
    # voltage sweep -> last index of the paired current sweep. Values not
    # listed here are left unchanged.
    _PAIRED_LAST_INDEX = {"0": "1", "1": "0", "2": "3", "3": "2"}

    def _current_sweep_name(self, voltage_sweep_name):
        """Return the name of the current sweep paired with a voltage sweep.

        Args:
            voltage_sweep_name (str): key of the sweep in ``/acquisition``

        Returns:
            str: key of the matching sweep in ``/stimulus/presentation``
        """
        parts = voltage_sweep_name.split("_")
        last = parts[-1]
        if len(parts) == 2:
            # maps 00 -> 01, 01 -> 03, ... or 000 -> 001, etc.
            # (zero padding is kept at the width of the original index)
            parts[-1] = str(2 * int(last) + 1).rjust(len(last), "0")
        else:
            # maps 0_0_0 -> 0_0_1, 0_0_1 -> 0_0_0, etc.
            parts[-1] = self._PAIRED_LAST_INDEX.get(last, last)
        return "_".join(parts)

    def read(self):
        """ Read and format the content of the NWB file

        Returns:
            data (list of dict): list of traces; empty when target protocols
            were requested and none of them is "step" or "genericstep"
            (case-insensitive).

        Raises:
            KeyError: if the current sweep paired with a voltage sweep is not
                present in ``/stimulus/presentation``.
        """
        # Only return data if target_protocols is None or includes "step" or "genericstep"
        if self.target_protocols:
            allowed = [p.lower() for p in self.target_protocols]
            if "step" not in allowed and "genericstep" not in allowed:
                logger.warning(
                    "TRTNWBReader only supports 'step' and 'genericstep' protocols, "
                    f"but requested: {self.target_protocols}. Skipping."
                )
                return []

        data = []

        # possible paths in content:
        # /acquisition/index_00
        # or /acquisition/index_000
        # or /acquisition/Index_0_0_0
        for voltage_sweep_name, voltage_sweep in list(self.content["acquisition"].items()):
            # possible paths in content:
            # /stimulus/presentation/index_01
            # or /stimulus/presentation/index_001
            # or /stimulus/presentation/Index_0_0_1
            current_sweep_name = self._current_sweep_name(voltage_sweep_name)
            current_sweep = self.content["stimulus"]["presentation"][current_sweep_name]

            data.append(self._format_nwb_trace(
                voltage=voltage_sweep["data"],
                current=current_sweep["data"],
                start_time=voltage_sweep["starting_time"],
                trace_name=voltage_sweep_name
            ))

        return data

    def _format_nwb_trace(self, voltage, current, start_time, trace_name=None, repetition=None):
        """ Format the data from the NWB file to the format used by BluePyEfe

        Same as the base implementation, except that a known mix-up of the
        conversion factors and units is corrected before scaling the data.

        Args:
            voltage (Dataset): voltage series; its ``conversion`` and ``unit``
                attributes are read
            current (Dataset): current series; its ``conversion`` and ``unit``
                attributes are read
            start_time (Dataset): starting time; its ``rate`` and ``unit``
                attributes are read
            trace_name (Dataset): name of the trace
            repetition (int): repetition number

        Returns:
            dict: formatted trace
        """
        def as_text(unit):
            # Units may be stored as bytes or as str; each attribute is
            # checked on its own since they need not share the same type.
            return unit if isinstance(unit, str) else unit.decode('UTF-8')

        v_conversion = voltage.attrs["conversion"]
        i_conversion = current.attrs["conversion"]

        v_unit = as_text(voltage.attrs["unit"])
        i_unit = as_text(current.attrs["unit"])
        t_unit = as_text(start_time.attrs["unit"])

        if (
            v_conversion == 1e-12 and
            i_conversion == 0.001 and
            v_unit == "volts" and
            i_unit == "volts"
        ):
            # big mixup in units, correct it: the two conversion factors are
            # swapped and the current is labelled as volts
            v_conversion = 1e-3
            i_conversion = 1e-12
            i_unit = "amperes"

        v_array = numpy.array(
            voltage[()] * v_conversion, dtype="float32"
        )
        i_array = numpy.array(
            current[()] * i_conversion, dtype="float32"
        )

        dt = 1. / float(start_time.attrs["rate"])

        return {
            "voltage": v_array,
            "current": i_array,
            "dt": dt,
            "id": str(trace_name),
            "repetition": repetition,
            "i_unit": i_unit,
            "v_unit": v_unit,
            "t_unit": t_unit,
        }
class VUNWBReader(NWBReader):
    """Reader for NWB files whose stimulus sweeps carry one of the stimulus
    descriptions listed in ``PROTOCOL_VU_TO_BBP``."""

    # Stimulus descriptions whose initial segment is flattened in ``read``.
    _FLATTENED_STIMULI = (
        "CCSteps_DA_0",
        "X1PS_SubThresh_DA_0",
        "X4PS_SupraThresh_DA_0",
    )
    # Length of the flattened initial segment, in the unit of 1 / rate
    # (presumably seconds, i.e. 90 ms -- confirm against the files).
    _FLATTENED_DURATION = 0.090

    def __init__(self, content, target_protocols, in_data, repetition=None):
        """ Init

        Args:
            content (h5.File): NWB file
            target_protocols (list of str): list of the protocols to be read and returned
            in_data (dict): options of the recording; the optional keys
                "protocol_name" and "toff" are read by this reader
            repetition (list of int): id of the repetition(s) to be read and returned
        """
        # Go through the base initializer so that every attribute it defines
        # (including v_file) also exists on this reader.
        super().__init__(content, target_protocols, repetition=repetition)
        self.in_data = in_data

    def _get_target_protocols(self):
        """Return the protocols to read as a list.

        ``in_data["protocol_name"]`` takes precedence over the target
        protocols given at construction; a single string is wrapped in a list.
        """
        target_protocols = self.in_data.get("protocol_name", self.target_protocols)
        if isinstance(target_protocols, str):
            return [target_protocols]
        return target_protocols

    def read(self):
        """ Read and format the content of the NWB file

        For each stimulus sweep whose translated protocol name is requested:
        the trace is cut at the first NaN of the current, dropped if it ends
        before ``in_data["toff"]``, offset by the holding current and, for
        the stimuli in ``_FLATTENED_STIMULI``, has its first 90 ms replaced
        by the value at 90 ms.

        Returns:
            data (list of dict): list of traces
        """
        data = []
        target_protocols = self._get_target_protocols()

        for sweep_name, current_sweep in list(self.content["stimulus"]["presentation"].items()):
            # The description is stored either as an attribute or as a dataset.
            try:
                stimulus_description = current_sweep.attrs["stimulus_description"]
            except KeyError:
                stimulus_description = current_sweep["stimulus_description"][()][0].decode('UTF-8')

            if stimulus_description not in PROTOCOL_VU_TO_BBP:
                continue
            if PROTOCOL_VU_TO_BBP[stimulus_description] not in target_protocols:
                continue

            # The recording paired with stimulus "..DA.." is named "..AD..".
            voltage_sweep_name = sweep_name.replace("DA", "AD")
            acquisition = self.content["acquisition"]
            voltage_sweeps = acquisition["timeseries"] if "timeseries" in acquisition else acquisition
            if voltage_sweep_name not in voltage_sweeps:
                continue
            voltage_sweep = voltage_sweeps[voltage_sweep_name]

            trace = self._format_nwb_trace(
                voltage=voltage_sweep["data"],
                current=current_sweep["data"],
                start_time=voltage_sweep["starting_time"],
                trace_name=sweep_name
            )

            # Shorten protocols that finish with NaNs
            # NOTE(review): argmax returns 0 both when there is no NaN and
            # when the very first sample is NaN; neither case is truncated.
            first_nan = numpy.argmax(numpy.isnan(trace["current"]))
            if first_nan:
                trace["voltage"] = trace["voltage"][:first_nan]
                trace["current"] = trace["current"][:first_nan]

            # Remove the protocols that finish too early
            # (duration is len * dt * 1000, presumably in ms like toff -- confirm)
            duration = len(trace["current"]) * trace["dt"] * 1000
            if "toff" in self.in_data and self.in_data["toff"] > duration:
                continue

            # Offset the current with the holding current; bias_current is
            # scaled by 1e-12 (presumably pA to A -- confirm against the files)
            holding_current = float(voltage_sweep["bias_current"][()]) * 1e-12
            trace["current"] = numpy.asarray(trace["current"]) + holding_current

            # For Step, IV and IDRest protocols, replace the first 90 ms with the value at 90 ms
            if stimulus_description in self._FLATTENED_STIMULI:
                flat_end = int(self._FLATTENED_DURATION / trace["dt"])
                if flat_end < len(trace["current"]):
                    trace["current"][0:flat_end] = trace["current"][flat_end]
                    trace["voltage"][0:flat_end] = trace["voltage"][flat_end]
                else:
                    # The trace is shorter than 90 ms: keep it unmodified.
                    logger.info(
                        f"For {stimulus_description}, unable to replace the first 90 ms "
                        "with the value at 90 ms as current/voltage array is too short"
                    )

            data.append(trace)

        return data