from nibabel import cifti2
from typing import Tuple, Sequence
from fsl.utils import path
import numpy as np
file_types = [
(3001, ['.dconn.nii'], (cifti2.BrainModelAxis, cifti2.BrainModelAxis)),
(3002, ['.dtseries.nii'], (cifti2.SeriesAxis, cifti2.BrainModelAxis)),
(3003, ['.pconn.nii'], (cifti2.ParcelsAxis, cifti2.ParcelsAxis)),
(3004, ['.ptseries.nii'], (cifti2.SeriesAxis, cifti2.ParcelsAxis)),
(3006, ['.dscalar.nii', '.dfan.nii'], (cifti2.ScalarAxis, cifti2.BrainModelAxis)),
(3007, ['.dlabel.nii'], (cifti2.LabelAxis, cifti2.BrainModelAxis)),
(3008, ['.pscalar.nii'], (cifti2.ScalarAxis, cifti2.ParcelsAxis)),
(3009, ['.pdconn.nii'], (cifti2.BrainModelAxis, cifti2.ParcelsAxis)),
(3010, ['.dpconn.nii'], (cifti2.ParcelsAxis, cifti2.BrainModelAxis)),
(3011, ['.pconnseries.nii'], (cifti2.ParcelsAxis, cifti2.ParcelsAxis, cifti2.SeriesAxis)),
(3012, ['.pconnscalar.nii'], (cifti2.ParcelsAxis, cifti2.ParcelsAxis, cifti2.ScalarAxis)),
]
[docs]def guess_extension(axes: Tuple[cifti2.Axis]) -> Sequence[str]:
"""
Guesses the extension based on the CIFTI axes
:param axes: CIFTI axes describing the rows/columns of a CIFTI file
:return: tuple of possible file extensions
"""
for _, extensions, axes_types in file_types:
if len(axes_types) == len(axes) and all(isinstance(a, at) for a, at in zip(axes, axes_types)):
return tuple(extensions)
return ()
[docs]def write(filename: str, arr: np.ndarray, axes: Tuple[cifti2.Axis]):
"""
Writes a CIFTI file guessing the extension of the filename
:param filename: full filename of basename
:param arr: array to be stored
:param axes: CIFTI axes describing the rows/columns of a CIFTI file
"""
extensions = guess_extension(axes)
if len(extensions) == 0:
raise ValueError("No valid extensions found for axes of type {}".format(type(a) for a in axes))
new_filename = path.addExt(filename, allowedExts=extensions, mustExist=False, defaultExt=extensions[0])
cifti2.write(new_filename, arr, axes)
def _greyordinate_index(brain_model: cifti2.BrainModelAxis):
assert (brain_model.name[0] == brain_model.name).all()
idx = np.zeros(brain_model.size, dtype='i4')
if not brain_model.surface_mask.all():
voxel = brain_model.voxel[~brain_model.surface_mask]
idx[brain_model.volume_mask] = brain_model.volume_shape[0] * (
brain_model.volume_shape[1] * voxel[:, 2] +
voxel[:, 1]
) + voxel[:, 0]
idx[brain_model.surface_mask] = -brain_model.vertex[brain_model.surface_mask]
return idx
def _find_overlap(bm1: cifti2.BrainModelAxis, bm2: cifti2.BrainModelAxis):
full_idx1 = []
full_idx2 = []
as_dict1 = {n: (i, bm) for n, i, bm in bm1.iter_structures()}
for name, idx2, bm2_part in bm2.iter_structures():
if name in as_dict1:
idx1, bm1_part = as_dict1[name]
_, sub_idx1, sub_idx2 = np.intersect1d(
_greyordinate_index(bm1_part),
_greyordinate_index(bm2_part),
return_indices=True
)
sorter = np.argsort(sub_idx1)
full_idx1.append(idx1.start + sub_idx1[sorter])
full_idx2.append(idx2.start + sub_idx2[sorter])
full_idx1 = np.concatenate(full_idx1, 0)
full_idx2 = np.concatenate(full_idx2, 0)
assert (len(full_idx1) == 0 and len(full_idx2) == 0) or (bm1[full_idx1] == bm2[full_idx2])
return full_idx1, full_idx2
[docs]def combine(brain_models: Sequence[cifti2.BrainModelAxis]):
"""
Find the common space of multiple BrainModel axes
:param brain_models: sequence of brain model axes
:return: tuple of common brain model and sequence of indices with the common space
"""
common_bm = brain_models[0]
for bm in brain_models[1:]:
common_bm = common_bm[_find_overlap(common_bm, bm)[0]]
return common_bm, [_find_overlap(bm, common_bm)[0] for bm in brain_models]
[docs]def axis_from_hdf5(group: "h5py.Group"):
"""
Stores the information from an axis in HDF5 group
"""
name = group.attrs['name']
if name == 'None':
return None
if name == 'Scalar':
return cifti2.ScalarAxis(np.array(group['name']).astype('U'))
if name == 'BrainModel':
nvertices = {str(key): int(value) for key, value in
zip(group.attrs['nvertices_keys'], group.attrs['nvertices_values'])}
return cifti2.BrainModelAxis(
np.array(group['name']).astype('U'), np.array(group['voxel']),
np.array(group['vertex']), np.array(group['affine']) if 'affine' in group else None,
tuple(int(sz) for sz in group['volume_shape']) if 'volume_shape' in group else None,
nvertices)
if name == 'Parcels':
nvertices = {str(key): int(value) for key, value in
zip(group.attrs['nvertices_keys'], group.attrs['nvertices_values'])}
voxels = [np.array(group[f'voxels{idx}']) for idx in range(len(group['name']))]
vertices = []
for idx in range(len(group['name'])):
res = {}
for name in group.attrs['nvertices_keys']:
fname = f'vertices{idx}_{name}'
if fname in group:
res[name] = np.array(group[fname])
vertices.append(res)
return cifti2.ParcelsAxis(
np.array(group['name']).astype('S'), voxels, vertices,
np.array(group['affine']) if 'affine' in group else None,
tuple(int(sz) for sz in group['volume_shape']) if 'volume_shape' in group else None,
nvertices)
if name == 'Series':
return cifti2.SeriesAxis(
group.attrs['start'], group.attrs['step'], group.attrs['size'], group.attrs['unit']
)
raise ValueError(f"Reading {name} from HDF5 is not currently supported")
[docs]def axis_to_hdf5(group: "h5py.Group", axis: cifti2.Axis):
"""
Stores the information from an axis in HDF5 group
"""
if axis is None:
group.attrs['name'] = 'None'
elif isinstance(axis, cifti2.ScalarAxis):
group.attrs['name'] = 'Scalar'
group['name'] = axis.name.astype('S')
elif isinstance(axis, cifti2.BrainModelAxis):
group.attrs['name'] = 'BrainModel'
group['name'] = axis.name.astype('S')
group['voxel'] = axis.voxel
group['vertex'] = axis.vertex
if axis.affine is not None:
group['affine'] = axis.affine
if axis.volume_shape is not None:
group['volume_shape'] = axis.volume_shape
group.attrs['nvertices_keys'] = list(axis.nvertices.keys())
group.attrs['nvertices_values'] = list(axis.nvertices.values())
elif isinstance(axis, cifti2.ParcelsAxis):
group.attrs['name'] = 'Parcels'
for idx in range(len(axis)):
for key, value in axis.vertices[idx].items():
if key not in axis.nvertices:
raise KeyError(f"Defining vertices for undefined surface {key}")
group[f'vertices{idx}_{key}'] = value
group[f'voxels{idx}'] = axis.voxels[idx]
group['affine'] = axis.affine
group['name'] = axis.name.astype('S')
group.attrs['nvertices_keys'] = list(axis.nvertices.keys())
group.attrs['nvertices_values'] = list(axis.nvertices.values())
group['volume_shape'] = axis.volume_shape
elif isinstance(axis, cifti2.SeriesAxis):
group.attrs['name'] = 'Series'
group.attrs['start'] = axis.start
group.attrs['size'] = axis.size
group.attrs['step'] = axis.step
group.attrs['unit'] = axis.unit
else:
raise ValueError(f"storing {axis.__class__.__name__} in HDF5 is not currently supported")
[docs]def from_hdf5(group: "h5py.Group") -> Tuple["h5py.Dataset", Sequence[cifti2.Axis]]:
"""
Reads a CIFTI array from the HDF5 format
:param group: HDF5 group the data was stored in
:return: tuple with data array (still on disk) and sequence of axes
"""
data = group['data']
axes = [axis_from_hdf5(group[f'axis{idx}']) for idx in range(data.ndim)]
return data, axes
[docs]def to_hdf5(group: "h5py.Group", arr, axes: Sequence[cifti2.Axis], compression='gzip'):
"""
Store the CIFTI array in an HDF5 format
:param group: HDF5 group to store the data in (can be top-level HDF5 file)
:param arr: data array
:param axes: sequence of axes (optionally None)
:param compression: which compression to use on the main data array (None, 'gzip', or 'lzf')
"""
assert len(axes) == arr.ndim
assert all(ax is None or len(ax) == sz for sz, ax in zip(arr.shape, axes))
group.create_dataset('data', data=arr, compression=compression)
for idx, axis in enumerate(axes):
axis_to_hdf5(group.create_group(f'axis{idx}'), axis)
[docs]def empty_hdf5(group: "h5py.Group", axes: Sequence[cifti2.Axis], dtype=float, compression='gzip'):
"""
Creates a new HDF5 group with an empty dataset
:param group: HDF5 group
:param axes: sequence of axes (all have to be defined)
:param dtype: data type
:param compression: which compression to use on each chunk
:return: new array to be filled
"""
group.create_dataset('data', shape=(len(ax) for ax in axes), dtype=dtype, compression=compression)
for idx, axis in enumerate(axes):
axis_to_hdf5(group.create_group(f'axis{idx}'), axis)
return group['data']
[docs]def empty_zarr(group: "zarr.Group", axes: Sequence[cifti2.Axis], dtype=float, compressor='default'):
"""
Creates a new zarr group with an empty dataset
:param group: zarr group
:param axes: sequence of axes (all have to be defined)
:param dtype: data type
:param compressor: which compressor to use on each chunk
:return: new array to be filled
"""
group.create_dataset('data', shape=(len(ax) for ax in axes), dtype=dtype, compressor=compressor)
for idx, axis in enumerate(axes):
axis_to_hdf5(group.create_group(f'axis{idx}'), axis)
return group['data']