#!/usr/bin/env python

# Script that performs B1 field inhomogeneity correction for a DWI volume series
# Bias field is estimated using the mean b=0 image, and subsequently used to correct all volumes

# Make the corresponding MRtrix3 Python libraries available
import inspect, os, sys
# Resolve the 'lib' directory relative to the real location of this script
# (<script directory>/../lib); realpath() is applied so that a symlinked
# script is resolved to the actual file before the relative lookup.
script_dir = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe())))
lib_folder = os.path.realpath(os.path.join(script_dir, os.pardir, 'lib'))
if not os.path.isdir(lib_folder):
  # Terminate the message with a newline so the shell prompt is not appended to it
  sys.stderr.write('Unable to locate MRtrix3 Python libraries\n')
  sys.exit(1)
# Insert at the front of the module search path so this directory is searched first
sys.path.insert(0, lib_folder)

from distutils.spawn import find_executable
from mrtrix3 import app, fsl, image, path, run

# N4BiasFieldCorrection options exposed through this script.
# Each key is the single-letter N4 option; each value is a tuple of
# (default value, description). The default value is also displayed as the
# metavar of the corresponding '-ants.<key>' command-line option.
opt_N4BiasFieldCorrection = dict([
    ('s', ('4',
           'shrink-factor applied to spatial dimensions')),
    ('b', ('[100,3]',
           '[initial mesh resolution in mm, spline order] '
           'This value is optimised for human adult data '
           'and needs to be adjusted for rodent data.')),
    ('c', ('[1000,0.0]',
           '[numberOfIterations,convergenceThreshold]')),
])

# Describe the script and declare its command-line interface
app.init('Robert E. Smith (robert.smith@florey.edu.au)',
         'Perform B1 field inhomogeneity correction for a DWI volume series')

# Citations that apply depending on which estimation algorithm is selected
app.cmdline.addCitation('If using -fast option', 'Zhang, Y.; Brady, M. & Smith, S. Segmentation of brain MR images through a hidden Markov random field model and the expectation-maximization algorithm. IEEE Transactions on Medical Imaging, 2001, 20, 45-57', True)
app.cmdline.addCitation('If using -fast option', 'Smith, S. M.; Jenkinson, M.; Woolrich, M. W.; Beckmann, C. F.; Behrens, T. E.; Johansen-Berg, H.; Bannister, P. R.; De Luca, M.; Drobnjak, I.; Flitney, D. E.; Niazy, R. K.; Saunders, J.; Vickers, J.; Zhang, Y.; De Stefano, N.; Brady, J. M. & Matthews, P. M. Advances in functional and structural MR image analysis and implementation as FSL. NeuroImage, 2004, 23, S208-S219', True)
app.cmdline.addCitation('If using -ants option', 'Tustison, N.; Avants, B.; Cook, P.; Zheng, Y.; Egan, A.; Yushkevich, P. & Gee, J. N4ITK: Improved N3 Bias Correction. IEEE Transactions on Medical Imaging, 2010, 29, 1310-1320', True)

# One '-ants.<key>' option per supported N4BiasFieldCorrection parameter;
# the default value is shown as the metavar in the help page
antsoptions = app.cmdline.add_argument_group('Options for ANTS N4BiasFieldCorrection')
for key in sorted(opt_N4BiasFieldCorrection):
  antsoptions.add_argument('-ants.'+key, metavar=opt_N4BiasFieldCorrection[key][0], help='N4BiasFieldCorrection option -%s. %s' % (key,opt_N4BiasFieldCorrection[key][1]))
  # Refer to the option without its leading dash, consistent with the other
  #   flagMutuallyExclusiveOptions() calls below and with the attribute name
  #   ('ants.<key>') through which the parsed value is later read from app.args
  app.cmdline.flagMutuallyExclusiveOptions( [ 'ants.'+key, 'fsl' ] )

# Positional arguments
app.cmdline.add_argument('input',  help='The input image series to be corrected')
app.cmdline.add_argument('output', help='The output corrected image series')

# Options specific to this script
options = app.cmdline.add_argument_group('Options for the dwibiascorrect script')
options.add_argument('-mask', help='Manually provide a mask image for bias field estimation')
options.add_argument('-bias', help='Output the estimated bias field')
options.add_argument('-ants', action='store_true', help='Use ANTS N4 to estimate the inhomogeneity field')
options.add_argument('-fsl', action='store_true', help='Use FSL FAST to estimate the inhomogeneity field')
app.cmdline.flagMutuallyExclusiveOptions( [ 'ants', 'fsl' ] )
options.add_argument('-grad', help='Pass the diffusion gradient table in MRtrix format')
options.add_argument('-fslgrad', nargs=2, metavar=('bvecs', 'bvals'), help='Pass the diffusion gradient table in FSL bvecs/bvals format')
app.cmdline.flagMutuallyExclusiveOptions( [ 'grad', 'fslgrad' ] )
app.parse()

# Check that the requested estimation algorithm can be run, and collect the
# algorithm-specific settings used later on (fast_cmd for FSL, ants_options for ANTS)
if app.args.fsl:

  if app.isWindows():
    app.error('Script cannot run using FSL on Windows due to FSL dependency')

  fsl_path = os.environ.get('FSLDIR', '')
  if not fsl_path:
    app.error('Environment variable FSLDIR is not set; please run appropriate FSL configuration script')

  fast_cmd = fsl.exeName('fast')

  # Each fragment must end with a space so that sentences do not run together when concatenated
  app.warn('Use of -fsl option in dwibiascorrect script is discouraged due to its strong dependence ' + \
           'on initial brain masking, and its inability to correct voxels outside of this mask. ' + \
           'Use of the -ants option is recommended for quantitative DWI analyses.')

elif app.args.ants:

  if not find_executable('N4BiasFieldCorrection'):
    app.error('Could not find ANTS program N4BiasFieldCorrection; please check installation')

  # Replace the default of any N4 option for which the user supplied a value;
  #   an absent attribute is treated the same as an option that was not given
  for key in sorted(opt_N4BiasFieldCorrection):
    val = getattr(app.args, 'ants.'+key, None)
    if val is not None:
      opt_N4BiasFieldCorrection[key] = (val, 'user defined')
  # Render every option as '-<key> <value>' for the N4BiasFieldCorrection command line
  ants_options = ' '.join(['-%s %s' %(k, v[0]) for k, v in opt_N4BiasFieldCorrection.items()])
else:
  app.error('No bias field estimation algorithm specified')

# Build the gradient table option appended to the mrconvert call that imports
# the input image; it stays empty when no table is given on the command line
if app.args.grad:
  grad_import_option = ' -grad ' + path.fromUser(app.args.grad, True)
elif app.args.fslgrad:
  # -fslgrad always carries exactly two entries (nargs=2): bvecs then bvals
  grad_import_option = ' -fslgrad ' + ' '.join([ path.fromUser(entry, True) for entry in app.args.fslgrad ])
else:
  grad_import_option = ''

# Check both requested output paths up front (the -bias path may be unset)
for requested_output in (app.args.output, app.args.bias):
  app.checkOutputPath(requested_output)

app.makeTempDir()

# Copy the input series (with any user-supplied gradient table) and the
# optional mask into the temporary directory under fixed names
import_dwi_cmd = 'mrconvert ' + path.fromUser(app.args.input, True) + ' ' + path.toTemp('in.mif', True) + grad_import_option
run.command(import_dwi_cmd)
if app.args.mask:
  import_mask_cmd = 'mrconvert ' + path.fromUser(app.args.mask, True) + ' ' + path.toTemp('mask.mif', True)
  run.command(import_mask_cmd)

app.gotoTempDir()

# Make sure it's actually a DWI that's been passed:
#   the image must be 4D, and must carry a gradient table with one entry per volume
dwi_header = image.Header('in.mif')
if len(dwi_header.size()) != 4:
  app.error('Input image must be a 4D image')
if 'dw_scheme' not in dwi_header.keyval():
  app.error('No valid DW gradient scheme provided or present in image header')
num_grad_entries = len(dwi_header.keyval()['dw_scheme'])
num_volumes = dwi_header.size()[3]
if num_grad_entries != num_volumes:
  # Both counts are integers and must be converted explicitly before concatenation
  app.error('DW gradient scheme contains different number of entries (' + str(num_grad_entries) + ') to number of volumes in DWIs (' + str(num_volumes) + ')')

# Without a user-supplied mask, derive one from the DWIs; otherwise make sure
# the supplied mask has the same size as the DWIs along the three spatial axes
if not app.args.mask:
  run.command('dwi2mask in.mif mask.mif')
elif image.Header('mask.mif').size()[:3] != dwi_header.size()[:3]:
  app.error('Provided mask image does not match input DWI')

# Extract the b=0 volumes and average them along the volume axis
run.command('dwiextract in.mif - -bzero | mrmath - mean mean_bzero.mif -axis 3')

# Estimate the bias field from the mean b=0 image; either branch leaves the
# location of the estimated field in bias_path
if app.args.fsl:

  # FAST has no mask input, so the image is multiplied by the mask before being handed over
  run.command('mrcalc mean_bzero.mif mask.mif -mult - | mrconvert - mean_bzero_masked.nii -strides -1,+2,+3')
  run.command(fast_cmd + ' -t 2 -o fast -n 3 -b mean_bzero_masked.nii')
  bias_path = fsl.findImage('fast_bias')

elif app.args.ants:

  # N4 will fail if a manually-provided mask does not match the input image perfectly
  #   (i.e. also transform and voxel sizes), so check this beforehand
  if app.args.mask and not image.match('mean_bzero.mif', 'mask.mif'):
    app.error('Input mask header does not perfectly match DWI as required by N4')

  # The brain mask is passed to N4 as a weights image rather than as a mask, so that voxels at the
  #   edge of the mask receive a smoothly-varying correction rather than a factor of 1.0 outside the mask
  for basename in ('mean_bzero', 'mask'):
    run.command('mrconvert ' + basename + '.mif ' + basename + '.nii -strides +1,+2,+3')
  corrected_path = 'corrected.nii'
  init_bias_path = 'init_bias.nii'
  n4_outputs = '[' + corrected_path + ',' + init_bias_path + ']'
  run.command('N4BiasFieldCorrection -d 3 -i mean_bzero.nii -w mask.nii -o ' + n4_outputs + ' ' + ants_options)

  # N4 can introduce large differences between subjects via a global scaling of the bias field.
  # Compensate using the ratio of the image sums within the brain mask after and before correction:
  #   multiply by the mask, sum along each spatial axis in turn, and print the single remaining value
  masked_sum_pipeline = ' mask.mif -mult - | mrmath - sum - -axis 0 | mrmath - sum - -axis 1 | mrmath - sum - -axis 2 | mrdump -'
  input_integral  = float(run.command('mrcalc mean_bzero.mif' + masked_sum_pipeline)[0])
  output_integral = float(run.command('mrcalc ' + corrected_path + masked_sum_pipeline)[0])
  app.var(input_integral, output_integral)
  bias_path = 'bias.mif'
  global_scaling = output_integral / input_integral
  run.command('mrcalc ' + init_bias_path + ' ' + str(global_scaling) + ' -mult ' + bias_path)



# Divide the input series by the estimated bias field, then export the
# corrected series (and the field itself, if requested) to the user's paths
force_option = ' -force' if app.forceOverwrite else ''
run.command('mrcalc in.mif ' + bias_path + ' -div result.mif')
run.command('mrconvert result.mif ' + path.fromUser(app.args.output, True) + force_option)
if app.args.bias:
  run.command('mrconvert ' + bias_path + ' ' + path.fromUser(app.args.bias, True) + force_option)
app.complete()
