'''Python implementation of multidimensional GRAPPA.'''

from collections import defaultdict

import numpy as np
from skimage.util import view_as_windows

def _window(corner, kernel_size):
    '''Return slices selecting the kernel-sized patch that starts at `corner`.

    `corner` indexes the *unpadded* spatial grid. Because the arrays are
    padded by ``kernel_size // 2`` on every side, the center of the
    returned patch (in padded coordinates) is the sample at `corner` in
    unpadded coordinates.
    '''
    return tuple(slice(c, c + k) for c, k in zip(corner, kernel_size))


def _train_kernel(A, p0, ctr, lamda):
    '''Fit one GRAPPA kernel by Tikhonov-regularized least squares.

    Parameters
    ----------
    A : array, shape (n_patches, prod(kernel_size), nc)
        All overlapping calibration patches.
    p0 : bool array, shape (prod(kernel_size),)
        Sampling pattern; True marks patch positions used as sources.
    ctr : int
        Flat (C-order) index of the patch center, i.e. the target position.
    lamda : float
        Regularization constant. It is scaled by the Frobenius norm of
        ``S^H S`` divided by the number of source columns.

    Returns
    -------
    W : array, shape (sum(p0)*nc, nc)
        Weights mapping flattened source samples to the target's coils.
    '''
    S = A[:, p0, :].reshape(A.shape[0], -1)
    T = A[:, ctr, :]
    ShS = S.conj().T @ S
    ShT = S.conj().T @ T
    lamda0 = lamda*np.linalg.norm(ShS)/ShS.shape[0]
    return np.linalg.solve(ShS + lamda0*np.eye(ShS.shape[0]), ShT)


def mdgrappa(
        kspace, calib=None, kernel_size=None, coil_axis=-1, lamda=0.01,
        nnz=None, weights=None, ret_weights=False):
    '''GeneRalized Autocalibrating Partially Parallel Acquisitions.

    Parameters
    ----------
    kspace : N-D array
        Measured undersampled complex k-space data. N-1 dimensions hold
        spatial frequency axes (kx, ky, kz, etc.). One dimension holds coil
        images (`coil_axis`). The missing entries should be exactly 0. The
        sampling mask is taken from the first coil only.
    calib : N-D array or None, optional
        Fully sampled calibration data with the same coil layout as
        `kspace`. Automatic extraction (``calib=None``) is not implemented
        yet and raises ``NotImplementedError``.
    kernel_size : tuple or None, optional
        The size of the N-1 dimensional GRAPPA kernels: (kx, ky, ...).
        Entries of 1 are allowed. Default: ``(5,)*(kspace.ndim-1)``.
    coil_axis : int, optional
        Dimension holding coil images.
    lamda : float, optional
        Tikhonov regularization constant for kernel calibration.
    nnz : int or None, optional
        Number of nonzero elements in a multidimensional patch required
        to train/apply a kernel. Holes whose patch has fewer samples are
        left at 0. Default: ``int(sqrt(prod(kernel_size)))``.
    weights : dict, optional
        Maps sampling patterns to trained kernels, e.g. as returned with
        ``ret_weights=True``. Patterns missing from `weights` are trained
        from `calib`.
    ret_weights : bool, optional
        Return the kernels used as a dictionary mapping sampling patterns
        to kernels. Default is ``False``.

    Returns
    -------
    res : array_like
        k-space data where missing entries have been filled in. It has
        the same shape and dtype as `kspace`.
    weights : dict, optional
        Returned if ``ret_weights=True``.

    Raises
    ------
    ValueError
        If `kernel_size` does not have one entry per non-coil axis.
    NotImplementedError
        If `calib` is None.

    Notes
    -----
    Based on the GRAPPA algorithm described in [1]_.

    All axes (except coil axis) are used for GRAPPA reconstruction.

    References
    ----------
    .. [1] Griswold, Mark A., et al. "Generalized autocalibrating
           partially parallel acquisitions (GRAPPA)." Magnetic
           Resonance in Medicine: An Official Journal of the
           International Society for Magnetic Resonance in Medicine
           47.6 (2002): 1202-1210.
    '''

    # Coils to the back
    kspace = np.moveaxis(kspace, coil_axis, -1)
    nc = kspace.shape[-1]
    ndim = kspace.ndim - 1  # number of spatial (non-coil) axes

    # Make sure we have a kernel_size with one entry per spatial axis
    if kernel_size is None:
        kernel_size = (5,)*ndim
    kernel_size = tuple(int(k) for k in kernel_size)
    if len(kernel_size) != ndim:
        raise ValueError('kernel_size must have %d entries' % ndim)

    # Only consider sampling patterns that have at least nnz samples
    if nnz is None:
        nnz = int(np.sqrt(np.prod(kernel_size)))

    # User must supply the calibration region separately for now
    if calib is None:
        raise NotImplementedError("Auto ACS extraction not implemented!")
    calib = np.moveaxis(calib, coil_axis, -1)

    # Pad the spatial axes so every sample can sit at a kernel center
    pads = [k // 2 for k in kernel_size]
    spatial_shape = kspace.shape[:-1]
    pad_width = [(pd, pd) for pd in pads] + [(0, 0)]
    kspace = np.pad(kspace, pad_width, mode='constant')
    calib = np.pad(calib, pad_width, mode='constant')

    # Region of the padded arrays that holds the original samples. An
    # explicit stop index is used because slice(pd, -pd) is empty when pd == 0.
    interior = tuple(
        slice(pd, pd + n) for pd, n in zip(pads, spatial_shape))

    # Group every hole by the sampling pattern of the patch around it
    mask = np.abs(kspace[..., 0]) > 0
    patterns = defaultdict(list)
    for idx in np.argwhere(~mask[interior]):
        p0 = mask[_window(idx, kernel_size)].ravel()
        if np.sum(p0) >= nnz:  # only counts if it has enough samples
            patterns[tuple(p0.astype(int))].append(idx)

    # All overlapping kernel-sized patches of the calibration data, laid
    # out as (patch, kernel position in C order, coil)
    A = np.lib.stride_tricks.sliding_window_view(
        calib, kernel_size + (nc,)).reshape((-1, np.prod(kernel_size), nc))

    # Flat index of the kernel center: the target of each fit
    ctr = np.ravel_multi_index(pads, dims=kernel_size)

    recon = np.zeros(kspace.shape, dtype=kspace.dtype)
    weights2return = {}
    for key, holes in patterns.items():
        p0 = np.array(key, dtype=bool)

        # Use provided weights if we can, else compute them
        if weights is not None and key in weights:
            W = weights[key]
        else:
            W = _train_kernel(A, p0, ctr, lamda)
        if ret_weights:
            weights2return[key] = W

        # Gather the sources of every hole with this pattern and fill all of
        # them with a single matrix multiply
        holes = np.asarray(holes)
        sources = np.stack([
            kspace[_window(idx, kernel_size)].reshape((-1, nc))[p0, :].ravel()
            for idx in holes])
        targets = tuple((holes + pads).T)  # hole centers, padded coordinates
        recon[targets] = sources @ W

    # Add back in the measured samples, crop the padding, restore coil axis
    recon[mask] += kspace[mask]
    recon = np.moveaxis(recon[interior], -1, coil_axis)

    if ret_weights:
        return recon, weights2return
    return recon
if __name__ == '__main__':
    # Library module: executing it directly intentionally does nothing.
    pass