import logging
import warnings
import gwcs
import numpy as np
from astropy.modeling.models import Identity, Scale, Shift
from astropy.stats import sigma_clipped_stats as scs
from astropy.utils.exceptions import AstropyUserWarning
from scipy.signal import find_peaks
from stdatamodels.jwst import datamodels
from stdatamodels.jwst.datamodels import dqflags
from jwst.adaptive_trace_model.bspline import bspline_fit
from jwst.assign_wcs.nirspec import nrs_ifu_wcs
from jwst.lib.pipe_utils import match_nans_and_flags
__all__ = [
"fit_2d_spline_trace",
"linear_oversample",
"fit_all_regions",
"oversample_flux",
"fit_and_oversample",
]
log = logging.getLogger(__name__)
def _get_weights_for_fit_scale(ratio, model_fit):
"""
Get weights for scaling the spline model to the fit data.
Sets weight to zero for outliers and invalid data points.
Parameters
----------
ratio : ndarray
Ratio between the fit data and the model evaluated at the data points.
model_fit : ndarray
The spline model evaluated at the data points.
Returns
-------
weights : ndarray
Array matching the ``ratio`` shape, containing weights for each
ratio data point.
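
Examples
--------
A small synthetic illustration (values verified against the logic
below): a single outlying ratio receives zero weight, and the
remaining weights are normalized to sum to 1.

>>> import numpy as np
>>> ratio = np.ones(10)
>>> ratio[3] = 50.0  # outlier
>>> weights = _get_weights_for_fit_scale(ratio, np.ones(10))
>>> float(weights[3])
0.0
>>> round(float(weights.sum()), 3)
1.0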
"""
# Weights start off proportional to flux of the model
weights = model_fit.copy()
# Weights are zero for any pixel where the data was NaN
weights[~np.isfinite(ratio)] = 0
# Weights are zero for any pixel where the data or model was negative
weights[(ratio < 0) | (model_fit < 0)] = 0
# Identify the 5 largest weight points
order = np.argsort(weights)
largest_5 = (weights >= weights[order[-5]]) & (np.isfinite(ratio))
# Sigma-clipped mean and rms of these 5 ratios
mean, _, rms = scs(ratio[largest_5])
# Bad if over 2 sigma away
bad = np.abs(mean - ratio) > (2 * rms)
weights[bad] = 0
# Normalize weights
weights /= np.nansum(weights)
return weights
def fit_2d_spline_trace(
flux,
alpha,
fit_scale=None,
lrange=50,
col_index=None,
require_ngood=10,
spline_bkpt=50,
space_ratio=1.2,
):
"""
Create a trace model from spline fits to a single slit/slice image.
Image must be oriented so that wavelengths are along x-axis. Each
column is fit separately, with a window to include nearby data.
Parameters
----------
flux : ndarray
Input 2D flux image to fit.
alpha : ndarray
Alpha coordinates for input flux.
fit_scale : ndarray, optional
Array of scale values to apply to the input flux before fitting.
lrange : int, optional
Local column range for data to include in the fit, to the
left and right of each input column.
col_index : iterable or None, optional
Iterable or generator that produces column index values to fit.
If provided, columns will be fit in the order specified.
If not provided, columns will be fit left to right.
require_ngood : int, optional
Minimum number of data points required to attempt a fit in a column.
spline_bkpt : int, optional
Number of spline breakpoints (knots).
space_ratio : float, optional
Maximum spacing ratio to allow fitting to continue. If
the tenth-largest spacing in the input alpha values is larger
than the knot spacing by this ratio, the underlying fit returns
None instead of attempting to fit the data.
Returns
-------
splines : dict
Keys are column index numbers, values are `~scipy.interpolate.BSpline`.
If a spline model could not be fit, the column index number is not present.
scales : dict
Keys are column index numbers, values are floating point scales, to pair with
the returned models. If a spline model could not be fit, the column index
number is not present.
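
Examples
--------
A usage sketch with synthetic inputs. Whether each column is fit
successfully depends on the underlying ``bspline_fit``, so this is
illustrative rather than a guaranteed output::

    import numpy as np

    rng = np.random.default_rng(42)
    ny, nx = 40, 200
    alpha = np.tile(np.linspace(-2, 2, ny)[:, None], (1, nx))
    flux = np.exp(-0.5 * (alpha / 0.3) ** 2) + 0.01 * rng.normal(size=(ny, nx))
    splines, scales = fit_2d_spline_trace(flux, alpha, spline_bkpt=20)
    # splines and scales are keyed by fitted column index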
"""
# Define a fallback spline model, initialize to None
spline_model_save = None
# Set up the column fitting order if not provided
xsize = flux.shape[-1]
if col_index is None:
col_index = range(0, xsize, 1)
# Scale the flux for fitting
if fit_scale is not None:
scaled_flux = flux / fit_scale
else:
scaled_flux = flux
# Loop over columns in the slit/slice
splines = {}
scales = {}
for i in col_index:
col_flux = flux[:, i]
col_alpha = alpha[:, i]
ngood = np.sum(np.isfinite(col_flux) & np.isfinite(col_alpha))
if ngood <= require_ngood:
continue
# Get local alpha and flux values for fitting
lstart = np.max([i - lrange, 0])
lstop = np.min([i + lrange, xsize])
local_alpha = alpha[:, lstart:lstop]
local_data = scaled_flux[:, lstart:lstop]
# Trim to finite values
finite_values = np.isfinite(local_alpha) & np.isfinite(local_data)
local_alpha = local_alpha[finite_values]
local_data = local_data[finite_values]
# Sort by alpha
idx = np.argsort(local_alpha)
local_alpha = local_alpha[idx]
local_data = local_data[idx]
# Fit a bspline to the local data
try:
bspline = bspline_fit(
local_alpha,
local_data,
nbkpts=spline_bkpt,
wrapsig_low=2.5,
wrapsig_high=2.5,
wrapiter=3,
space_ratio=space_ratio,
verbose=False,
)
# If this routine could not get a fit (returned None) use the saved fit
if bspline is None:
spline_model = spline_model_save
else:
spline_model = bspline
except (ValueError, RuntimeError) as err:
log.warning(f"Spline fit failed at column {i}: {str(err)}")
spline_model = spline_model_save
# Check for a good model
if spline_model is None:
continue
# Store the spline model for the column
splines[i] = spline_model
spline_model_save = spline_model
# Evaluate the bspline at the valid input locations to determine
# a scale factor for the fit
idx = np.isfinite(col_alpha)
col_alpha = col_alpha[idx]
col_flux = col_flux[idx]
col_fit = spline_model(col_alpha)
# Determine the normalization by the weighted mean ratio between model and data
# Weights are based on the model so that we can reject outliers
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
ratio = col_flux / col_fit
# Compute weights for the ratio, rejecting outliers and invalid data
weights = _get_weights_for_fit_scale(ratio, col_fit)
wmeanratio = np.nansum(ratio * weights)
# Store the scale factor for the column
scales[i] = wmeanratio
return splines, scales
def _reindex(xmin, xmax, scale=2.0):
"""
Convert pixel positions on the old grid to oversampled positions.
For example, with oversample scale = 2, [0, 1, 2] goes to
old_x = [-0.25, 0.25, 0.75, 1.25, 1.75, 2.25], for
new_x = [0, 1, 2, 3, 4, 5].
With oversample scale = 3, [0, 1, 2] goes to
old_x = [-0.33, 0, 0.33, 0.67, 1, 1.33, 1.67, 2, 2.33], for
new_x = [0, 1, 2, 3, 4, 5, 6, 7, 8].
Parameters
----------
xmin : int
Minimum index.
xmax : int
Maximum index.
scale : float, optional
Oversample scaling factor.
Returns
-------
new_x : ndarray of int
Array of indices in the new grid.
old_x : ndarray of float
Array of coordinates in the old grid, corresponding to the new indices.
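
Examples
--------
The scale = 2 case from above, as a doctest:

>>> new_x, old_x = _reindex(0, 2, scale=2.0)
>>> new_x.tolist()
[0, 1, 2, 3, 4, 5]
>>> old_x.tolist()
[-0.25, 0.25, 0.75, 1.25, 1.75, 2.25]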
"""
# Indices in the new array
new_x = np.arange(xmin * scale, (xmax + 1) * scale, dtype=np.int32)
# Indices in the old array, scaled for new pixel spacing
# Also offset to center new coordinates on the old
old_x = new_x / scale - (scale - 1) / (scale * 2)
return new_x, old_x
def _is_compact_source(
alpha_slice, alpha_ptsource, native_dalpha, spline_bkpt, pad=3, require_npt=50
):
"""
Determine which pixels within a slice contain a compact source.
Parameters
----------
alpha_slice : ndarray
Alpha coordinates for the output slice. If oversampling is performed,
these should be the oversampled coordinates.
alpha_ptsource : ndarray
Array of alpha values for modeled flux that met the slope limit threshold.
native_dalpha : float
Approximate native pixel size in alpha, along the columns.
spline_bkpt : int
The number of breakpoints used in the spline modeling.
pad : int, optional
The number of pixels around peak data for which to include the
spline fit in the output array.
require_npt : int, optional
The minimum required number of high-slope data points to consider any
pixels to be compact.
Returns
-------
is_compact : ndarray
Boolean array matching the shape of ``alpha_slice``, where True
indicates a pixel containing a compact source.
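
Examples
--------
A synthetic illustration: 60 high-slope alpha values clustered near
zero flag the pixels within ``pad`` native pixels of the peak.

>>> import numpy as np
>>> alpha_slice = np.arange(-5.0, 6.0)
>>> alpha_ptsource = np.full(60, 0.3)
>>> mask = _is_compact_source(alpha_slice, alpha_ptsource, 1.0, 8)
>>> int(mask.sum())
6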
"""
is_compact = np.full(alpha_slice.shape, False)
# If there is not enough high slope data or no spline models were found,
# just return False for all data
if len(alpha_ptsource) < require_npt or spline_bkpt is None:
return is_compact
# Bin the alpha coordinates for the high slope locations
avec = np.arange(spline_bkpt) * native_dalpha / 2 - (native_dalpha * spline_bkpt / 4)
hist, edges = np.histogram(
alpha_ptsource,
bins=spline_bkpt,
range=(-native_dalpha * spline_bkpt / 4, native_dalpha * spline_bkpt / 4),
density=True,
)
hist = hist / np.nanmax(hist)
# Require peaks above some threshold
peak_indices, _ = find_peaks(hist, height=0.2)
amask = avec[peak_indices]
# Flag regions near the compact source, with some padding
for value in amask:
indx = (alpha_slice > value - pad * native_dalpha) & (
alpha_slice <= value + pad * native_dalpha
)
is_compact[indx] = True
return is_compact
def _trace_image(shape, spline_models, spline_scales, region_map, alpha, slope_limit=0.1, pad=3):
"""
Evaluate spline models at all pixels to generate a trace image.
The trace image will be NaN wherever a spline model was not fit and
wherever the source is not compact enough for the spline model
to be appropriate. The ``slope_limit`` parameter controls the decision
for compact source regions.
Parameters
----------
shape : tuple of int
Data shape for the output image.
spline_models : dict
Spline models to evaluate.
spline_scales : dict
Scaling factors for spline models.
region_map : ndarray
2D image matching shape, mapping valid region numbers.
alpha : ndarray
Alpha coordinates for all pixels marked as valid regions.
slope_limit : float, optional
The slope limit in the normalized model fits above which the spline
model is considered appropriate. Lower values will use spline fits
for fainter sources. If less than or equal to zero, the spline fits
will always be used.
pad : int, optional
The number of pixels around peak data for which to include the
spline fit in the output array.
Returns
-------
trace_used : ndarray
2D image containing the scaled spline data fit evaluated at
the input alpha coordinates for compact source regions. Values are
NaN where no spline model was available and where the source
was below the slope limit.
full_trace : ndarray
2D image containing the scaled spline data fit evaluated at
all pixels. Values are NaN where no spline model was available.
"""
trace_used = np.full(shape, np.nan, dtype=np.float32)
full_trace = np.full(shape, np.nan, dtype=np.float32)
alpha_slice = np.full(shape, np.nan, dtype=np.float32)
trace_slice = np.full(shape, np.nan, dtype=np.float32)
spline_bkpt = None
for slnum in spline_models:
splines = spline_models[slnum]
scales = spline_scales[slnum]
alpha_slice[:] = np.nan
trace_slice[:] = np.nan
indx = region_map == slnum
alpha_slice[indx] = alpha[indx]
# Define a list that will hold all alpha values for this
# slice where the slope is high
alpha_ptsource = []
# loop over columns
for i in range(shape[-1]):
if i not in splines:
continue
# Evaluate the spline model for relevant data
col_alpha = alpha_slice[:, i]
valid_alpha = np.isfinite(col_alpha)
col_fit = splines[i](col_alpha[valid_alpha])
# Set the edges to NaN to avoid edge effects
col_fit[0] = np.nan
col_fit[-1] = np.nan
scaled_fit = scales[i] * col_fit
trace_slice[:, i][valid_alpha] = scaled_fit
# Get the slope of the model fit prior to scaling
model_slope = np.abs(np.diff(col_fit, prepend=0))
# Ensure boundaries don't look weird
model_slope[0] = 0
model_slope[-1] = 0
highslope = (np.where(model_slope > slope_limit))[0]
alpha_ptsource.append(col_alpha[valid_alpha][highslope])
# Get the number of spline breakpoints used from the first real model
if spline_bkpt is None:
spline_bkpt = len(np.unique(splines[i].t)) - 1
full_trace[indx] = trace_slice[indx]
if slope_limit <= 0:
# Always use the spline fit in this case
trace_used[indx] = trace_slice[indx]
else:
if len(alpha_ptsource) > 0:
alpha_ptsource = np.concatenate(alpha_ptsource)
native_dalpha = np.abs(np.nanmedian(np.diff(alpha_slice, axis=0)))
compact_locations = _is_compact_source(
alpha_slice, alpha_ptsource, native_dalpha, spline_bkpt, pad
)
trace_used[compact_locations] = trace_slice[compact_locations]
total_used = np.sum(compact_locations)
log.debug(
f"Using {total_used}/{np.sum(indx)} pixels from the spline model for slice {slnum}"
)
return trace_used, full_trace
def _linear_interp(col_y, col_flux, y_interp, edge_limit=0, preserve_nan=True):
"""
Perform a linear interpolation at one column.
Parameters
----------
col_y : ndarray
Y values in the original data for the column.
col_flux : ndarray
Flux values in the original data for the column.
y_interp : ndarray
Y values to interpolate to.
edge_limit : int, optional
If greater than zero, this many pixels at the edges of
the interpolated values will be set to NaN.
preserve_nan : bool, optional
If True, NaNs in the input will be preserved in the output.
Returns
-------
interpolated_flux : ndarray
Interpolated flux array.
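
Examples
--------
NaNs in the input flux are propagated to output pixels whose nearest
input pixel is NaN:

>>> import numpy as np
>>> col_y = np.arange(5.0)
>>> col_flux = 2.0 * col_y
>>> col_flux[2] = np.nan
>>> _linear_interp(col_y, col_flux, np.array([0.5, 1.5])).tolist()
[1.0, nan]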
"""
valid_data = np.isfinite(col_flux)
valid_y = np.isfinite(col_y)
valid_interp = valid_y & valid_data
interpolated_flux = np.interp(y_interp, col_y[valid_interp], col_flux[valid_interp])
if edge_limit >= 1:
interpolated_flux[0:edge_limit] = np.nan
interpolated_flux[-edge_limit:] = np.nan
# Check for NaNs in the input: they should be preserved in the output
if preserve_nan:
closest_pix = np.round(y_interp).astype(int)
is_nan = ~np.isfinite(col_flux[closest_pix])
interpolated_flux[is_nan] = np.nan
return interpolated_flux
def linear_oversample(
data, region_map, oversample_factor, require_ngood, edge_limit=0, preserve_nan=True
):
"""
Oversample the input data with a linear interpolation.
Linear interpolation is performed for each column in each region in
the provided region map.
Parameters
----------
data : ndarray
Original data to oversample.
region_map : ndarray of int
Map containing the slice or slit number for valid regions.
Values are >0 for pixels in valid regions, 0 otherwise.
oversample_factor : float
Scaling factor to oversample by.
require_ngood : int
Minimum number of pixels required in a column to perform an interpolation.
edge_limit : int, optional
If greater than zero, this many pixels at the edges of
the interpolated values will be set to NaN.
preserve_nan : bool, optional
If True, NaNs in the input will be preserved in the output.
Returns
-------
os_data : ndarray
The oversampled data array.
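
Examples
--------
Oversampling a single-region image by a factor of 2 doubles the
number of rows; values are linearly interpolated along each column.

>>> import numpy as np
>>> data = np.arange(20, dtype=float).reshape(10, 2)
>>> region_map = np.ones((10, 2), dtype=int)
>>> os_data = linear_oversample(data, region_map, 2.0, require_ngood=3)
>>> os_data.shape
(20, 2)
>>> float(os_data[1, 0])
0.5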
"""
ysize, xsize = data.shape
_, basey = np.meshgrid(np.arange(xsize), np.arange(ysize))
os_shape = (int(np.ceil(ysize * oversample_factor)), xsize)
os_data = np.full(os_shape, np.nan, dtype=np.float32)
data_slice = np.full_like(data, np.nan)
y_slice = np.full_like(data, np.nan)
slice_numbers = np.unique(region_map[region_map > 0])
for slnum in slice_numbers:
data_slice[:] = np.nan
y_slice[:] = np.nan
# Copy the relevant data for this slice into the holding arrays
indx = region_map == slnum
data_slice[indx] = data[indx]
y_slice[indx] = basey[indx]
for ii in range(xsize):
valid_data = np.isfinite(data_slice[:, ii])
ngood = np.sum(valid_data)
if ngood <= require_ngood:
continue
col_y = y_slice[:, ii]
col_flux = data_slice[:, ii]
valid_y = np.isfinite(col_y)
newy, oldy = _reindex(
int(col_y[valid_y].min()), int(col_y[valid_y].max()), scale=oversample_factor
)
os_data[newy, ii] = _linear_interp(
col_y, col_flux, oldy, edge_limit=edge_limit, preserve_nan=preserve_nan
)
return os_data
def fit_all_regions(flux, alpha, region_map, signal_threshold, **fit_kwargs):
"""
Fit a trace model to all regions in the flux image.
Parameters
----------
flux : ndarray
The flux image to fit.
alpha : ndarray
Alpha coordinates for all flux values.
region_map : ndarray of int
Map containing the slice or slit number for valid regions.
Values are >0 for pixels in valid regions, 0 otherwise.
signal_threshold : dict
Threshold values for each valid region in the region map. If
the median peak value across columns in the region is below this
threshold, a fit will not be attempted for that region.
**fit_kwargs
Keyword arguments to pass to the fitting routine (see `fit_2d_spline_trace`).
Returns
-------
spline_models : dict
Keys are region numbers, values are dicts containing a spline model for
each column index in the region. If a spline model could not be fit, the
column index number is not present.
scales : dict
Keys are region numbers, values are dicts containing a floating point scale
for each spline model, by column index number. If a spline model could not
be fit, the column index number is not present.
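
Examples
--------
A usage sketch for a single-region image, assuming ``flux`` and
``alpha`` are 2D arrays as in `fit_2d_spline_trace` (whether fits
succeed depends on ``bspline_fit``)::

    import numpy as np

    region_map = np.ones(flux.shape, dtype=int)
    signal_threshold = {1: 0.0}  # fit region 1 regardless of signal
    spline_models, spline_scales = fit_all_regions(
        flux, alpha, region_map, signal_threshold, spline_bkpt=20
    )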
"""
# Arrays to reset with NaNs for each slice
data_slice = np.full_like(flux, np.nan)
alpha_slice = np.full_like(flux, np.nan)
spline_models = {}
spline_scales = {}
slice_numbers = np.unique(region_map[region_map > 0])
for slnum in slice_numbers:
log.info("Fitting slice %s", slnum)
# Reset holding arrays to NaN
data_slice[:] = np.nan
alpha_slice[:] = np.nan
# Copy the relevant data for this slice into the holding arrays
indx = region_map == slnum
data_slice[indx] = flux[indx]
alpha_slice[indx] = alpha[indx]
# A running sum in a given detector column (used for normalization)
runsum = np.nansum(data_slice, axis=0)
# Collapse the slice along Y to get max in each column
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
collapse = np.nanmax(data_slice, axis=0)
# Median column max across all columns
medcmax = np.nanmedian(collapse)
# Is medcmax over threshold? If so, do bspline for this slice.
dospline = False
if medcmax > signal_threshold[slnum]:
dospline = True
if dospline:
splines, scales = fit_2d_spline_trace(
data_slice, alpha_slice, fit_scale=runsum, **fit_kwargs
)
else:
splines = {}
scales = {}
spline_models[slnum] = splines
spline_scales[slnum] = scales
return spline_models, spline_scales
def oversample_flux(
flux,
alpha,
region_map,
spline_models,
spline_scales,
oversample_factor,
alpha_os,
require_ngood=10,
slope_limit=0.1,
psf_optimal=False,
trim_ends=False,
pad=3,
):
"""
Oversample a flux image from spline models fit to the data.
For each column in each slice or slit in the region map:
1. Check if there are enough valid data points to proceed.
2. Compute oversampled coordinates corresponding to the input column.
3. Linearly interpolate flux values onto the oversampled column.
4. If a spline fit is available, evaluate it for the original column
coordinates.
5. Construct a residual between the spline fit and the original column
data, then linearly interpolate the residual onto the oversampled
column.
6. Compute the slope of each column pixel as the difference between the
normalized spline model at that pixel and its immediate neighbor.
7. Evaluate the spline model at the oversampled coordinates.
The oversampled flux for each slice or slit is set from the spline flux
plus the interpolated residual, for pixels where the slope exceeds the
``slope_limit``. Otherwise, the flux is set to the linearly interpolated
value.
Parameters
----------
flux : ndarray
The flux image to fit.
alpha : ndarray
Alpha coordinates for all flux values.
region_map : ndarray of int
Map containing the slice or slit number for valid regions.
Values are >0 for pixels in valid regions, 0 otherwise.
spline_models : dict
Keys are region numbers, values are dicts containing a spline model for
each column index in the region. If a spline model could not be fit, the
column index number is not present.
spline_scales : dict
Keys are region numbers, values are dicts containing a floating point scale
for each spline model, by column index number. If a spline model could not
be fit, the column index number is not present.
oversample_factor : float
Scaling factor to oversample by.
alpha_os : ndarray
Alpha coordinates for the oversampled array, used to evaluate spline models
at every pixel.
require_ngood : int, optional
Minimum number of pixels required in a column to perform an interpolation.
slope_limit : float, optional
The slope limit in the normalized model fits above which the spline
model is considered appropriate. Lower values will use spline fits
for fainter sources. If less than or equal to zero, the spline fits
will always be used.
psf_optimal : bool, optional
If True, residual corrections to the spline model are not included
in the oversampled flux.
trim_ends : bool, optional
If True, the edges of the evaluated spline fit will be set to NaN.
pad : int, optional
The number of pixels around peak data for which to include the
spline fit in the output array.
Returns
-------
flux_os : ndarray
The oversampled flux array, containing contributions from the evaluated
spline models, linear interpolations, and residual corrections.
trace_used : ndarray
A trace model, generated from the spline models evaluated at
pixels containing a compact source.
full_trace : ndarray
A trace model, generated from the spline models evaluated at
every pixel.
linear_flux : ndarray
The flux linearly interpolated onto the oversampled grid.
residual_flux : ndarray
Residuals between the spline modeled data and the original flux,
linearly interpolated onto the oversampled grid.
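
Examples
--------
A usage sketch, given spline fits from `fit_all_regions` and
oversampled alpha coordinates ``alpha_os`` computed for the output
grid (see `fit_and_oversample` for how these are derived)::

    flux_os, trace_used, full_trace, linear, residual = oversample_flux(
        flux, alpha, region_map, spline_models, spline_scales,
        oversample_factor=2.0, alpha_os=alpha_os,
    )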
"""
ysize, xsize = flux.shape
_, basey = np.meshgrid(np.arange(xsize), np.arange(ysize))
# Oversampled flux array (linear and bspline to compare)
os_shape = (int(np.ceil(ysize * oversample_factor)), xsize)
flux_os_linear = np.full(os_shape, np.nan) # Linear interpolation
flux_os_bspline_full = np.full(os_shape, np.nan) # All bspline models
flux_os_bspline_use = np.full(os_shape, np.nan) # Actual bspline array applied
flux_os_residual = np.full(os_shape, np.nan) # Residual corrections
# Arrays to reset with NaNs for each slice
data_slice = np.full_like(flux, np.nan)
alpha_slice = np.full_like(flux, np.nan)
basey_slice = np.full_like(flux, np.nan)
alpha_os_slice = np.full(os_shape, np.nan)
reset_arrays = [data_slice, basey_slice, alpha_slice, alpha_os_slice]
# Edge limit for trimming ends
edge_limit = int(oversample_factor)
slice_numbers = np.unique(region_map[region_map > 0])
spline_bkpt = None
for slnum in slice_numbers:
# Reset holding arrays to NaN
for reset_array in reset_arrays:
reset_array[:] = np.nan
# Copy the relevant data for this slice into the holding arrays
indx = region_map == slnum
data_slice[indx] = flux[indx]
alpha_slice[indx] = alpha[indx]
basey_slice[indx] = basey[indx]
# Define a list that will hold all alpha values for this slice
# where the slope is high
alpha_ptsource = []
for ii in range(xsize):
# Are there sufficient values in this column to do anything?
valid_data = np.isfinite(data_slice[:, ii])
ngood = np.sum(valid_data)
if ngood <= require_ngood:
continue
# Get the relevant data for this column
col_y = basey_slice[:, ii]
col_alpha = alpha_slice[:, ii]
col_flux = data_slice[:, ii]
# newy gives the resampled Y pixel indices in the expanded detector frame
# oldy gives the resampled Y pixel coordinates in the original detector frame
valid_y = np.isfinite(col_y)
newy, oldy = _reindex(
int(col_y[valid_y].min()), int(col_y[valid_y].max()), scale=oversample_factor
)
# Default approach is to do linear interpolation
flux_os_linear[newy, ii] = _linear_interp(col_y, col_flux, oldy, edge_limit=edge_limit)
# Check for a spline fit for this column
if slnum not in spline_models or ii not in spline_models[slnum]:
continue
spline_model = spline_models[slnum][ii]
spline_scale = spline_scales[slnum][ii]
# Get the number of spline breakpoints used from the first real model
if spline_bkpt is None:
spline_bkpt = len(np.unique(spline_model.t)) - 1
# Get valid input locations and evaluate the spline
valid_alpha = np.isfinite(col_alpha)
col_fit = spline_model(col_alpha[valid_alpha])
scaled_fit = col_fit * spline_scale
# Construct the residual between spline fit and original data
# then oversample it to output frame by linear interpolation
residual = (col_flux[valid_alpha] - scaled_fit).astype(np.float32)
y_interp = col_y[valid_alpha]
valid_interp = np.isfinite(y_interp) & np.isfinite(residual)
interpval = np.interp(oldy, y_interp[valid_interp], residual[valid_interp])
if edge_limit >= 1:
interpval[0:edge_limit] = np.nan
interpval[-edge_limit:] = np.nan
flux_os_residual[newy, ii] = interpval
# What was the slope of the model fit prior to scaling?
model_slope = np.abs(np.diff(col_fit, prepend=0))
# Ensure boundaries don't look weird
if edge_limit >= 1:
model_slope[0:edge_limit] = 0
model_slope[-edge_limit:] = 0
# Add to our list of alpha values where the slope can be high for this slice
highslope = (np.where(model_slope > slope_limit))[0]
alpha_ptsource.append(col_alpha[valid_alpha][highslope])
# Store the oversampled alpha values to check against later
alpha_os_slice[newy, ii] = alpha_os[newy, ii]
# Evaluate the bspline at the oversampled alpha for this column
oversampled_fit = spline_model(alpha_os[newy, ii]) * spline_scale
if trim_ends and edge_limit >= 1:
oversampled_fit[0:edge_limit] = np.nan
oversampled_fit[-edge_limit:] = np.nan
flux_os_bspline_full[newy, ii] = oversampled_fit
# Now that the initial loop along the slice is done, we have a spline model everywhere.
# Next, look at the list of alpha values where model slopes were high to figure out
# where traces are and where we actually want to use the spline model
if slope_limit <= 0:
# Always use the spline fit in this case
flux_os_bspline_use = flux_os_bspline_full
else:
if len(alpha_ptsource) > 0:
alpha_ptsource = np.concatenate(alpha_ptsource)
native_dalpha = np.abs(np.nanmedian(np.diff(alpha_slice, axis=0)))
compact_locations = _is_compact_source(
alpha_os_slice, alpha_ptsource, native_dalpha, spline_bkpt, pad
)
flux_os_bspline_use[compact_locations] = flux_os_bspline_full[compact_locations]
total_used = np.sum(compact_locations)
log.debug(
f"Using {total_used}/{np.sum(indx)} pixels from the spline model for slice {slnum}"
)
# Insert the bspline interpolated values into the final combined oversampled array,
# starting from the linearly interpolated array
flux_os = flux_os_linear
indx = np.where(np.isfinite(flux_os_bspline_use))
flux_os[indx] = flux_os_bspline_use[indx]
# Unless we're doing a specific psf optimal extraction, add in the residual fit
if not psf_optimal:
log.info("Applying complex scene corrections.")
# DRL- conflicted about this indx array
# Using only where flux_os_bspline_use is finite will trim the slice edges a bit
# because the spline can extend slightly beyond the linear interpolation which can
# be bad for sources on the edge. But requiring the residual to also be finite
# can result in bad performance when the residual correction was really NEEDED
# on the edge.
indx = np.where(np.isfinite(flux_os_bspline_use) & np.isfinite(flux_os_residual))
flux_os[indx] += flux_os_residual[indx]
return flux_os, flux_os_bspline_use, flux_os_bspline_full, flux_os_linear, flux_os_residual
def _set_fit_kwargs(detector, xsize):
"""
Set optional parameters for spline fits by detector.
Parameters
----------
detector : str
Detector name.
xsize : int
Input size for the data, along the dispersion axis. Used
to determine the column index order for spline fits.
Returns
-------
fit_kwargs : dict
Optional parameter settings to pass to the ``fit_all_regions``
function.
Raises
------
ValueError
If the input detector is not supported.
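
Examples
--------
The empirical defaults differ per detector:

>>> kwargs = _set_fit_kwargs("NRS1", 2048)
>>> kwargs["spline_bkpt"], kwargs["require_ngood"]
(62, 15)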
"""
# Empirical parameters for this mode
if detector.startswith("NRS"):
require_ngood = 15
spline_bkpt = 62
lrange = 50
# This factor of 1.6 was dialed in based on inspection of the results,
# since sampling gets progressively worse for NIRSpec detectors
space_ratio = 1.6
# Set up the column fitting order by detector
if detector == "NRS1":
# For NRS1, start on the left of detector since the tilt wrt pixels is greatest here
col_index = range(0, xsize, 1)
else:
# For NRS2, start on the right of detector since the tilt wrt pixels is greatest here
col_index = range(xsize - 1, -1, -1)
elif detector.startswith("MIR"):
require_ngood = 8
spline_bkpt = 36
lrange = 50
space_ratio = 1.2
# For MIRI fitting order, we need to start on the left and run to the middle,
# and then on the right to the middle in order to have the middle
# section not go too far beyond last good fit
col_index = np.concatenate(
[np.arange(0, xsize // 2 + 1), np.arange(xsize - 1, xsize // 2, -1)]
)
else:
raise ValueError("Unknown detector")
fit_kwargs = {
"lrange": lrange,
"col_index": col_index,
"require_ngood": require_ngood,
"spline_bkpt": spline_bkpt,
"space_ratio": space_ratio,
}
return fit_kwargs
def _set_oversample_kwargs(detector):
"""
Set optional parameters for oversampling by detector.
Parameters
----------
detector : str
Detector name.
Returns
-------
oversample_kwargs : dict
Optional parameter settings to pass to the ``oversample_flux``
function.
Raises
------
ValueError
If the input detector is not supported.
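
Examples
--------
>>> _set_oversample_kwargs("NRS1")
{'pad': 2, 'trim_ends': True}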
"""
if detector.startswith("NRS"):
# Trimming ends of the interpolation can help with bad extrapolations
pad = 2
trim_ends = True
elif detector.startswith("MIR"):
# Trimming ends is bad for MIRI, where dithers place point sources near the ends
pad = 3
trim_ends = False
else:
raise ValueError("Unknown detector")
oversample_kwargs = {"pad": pad, "trim_ends": trim_ends}
return oversample_kwargs
def _get_alpha_nrs_ifu(ifu_wcs, xsize, ysize):
"""
Get alpha coordinates for NIRSpec IFU corresponding to the original data array.
Parameters
----------
ifu_wcs : list of `~gwcs.WCS`
List of WCS objects, one per slice.
xsize : int
X-size for the data array.
ysize : int
Y-size for the data array.
Returns
-------
alpha_orig : ndarray
Alpha coordinates for the data array, with shape (ysize, xsize).
"""
alpha_orig = np.full((ysize, xsize), np.nan)
for slice_wcs in ifu_wcs:
x, y = gwcs.wcstools.grid_from_bounding_box(slice_wcs.bounding_box)
_, alpha, _ = slice_wcs.transform("detector", "slicer", x, y)
idx = y.astype(int), x.astype(int)
# Flip alpha so in same direction as increasing Y
alpha_orig[*idx] = -alpha
return alpha_orig
def _get_alpha_mir_mrs(wcs, xsize, ysize):
"""
Get alpha coordinates for MIRI MRS corresponding to the original data array.
Parameters
----------
wcs : `~gwcs.WCS`
WCS object.
xsize : int
X-size for the data array.
ysize : int
Y-size for the data array.
Returns
-------
alpha_orig : ndarray
Alpha coordinates for the data array, with shape (ysize, xsize).
"""
x, y = np.meshgrid(np.arange(xsize), np.arange(ysize))
det2ab = wcs.get_transform("detector", "alpha_beta")
alpha_orig, _, _ = det2ab(x, y)
return alpha_orig
def _get_oversampled_coords_nrs_ifu(ifu_wcs, x_os, y_os):
"""
Get alpha coordinates for NIRSpec IFU corresponding to the oversampled data array.
Parameters
----------
ifu_wcs : list of `~gwcs.WCS`
List of WCS objects, one per slice.
x_os : ndarray
X coordinates for the oversampled data array.
y_os : ndarray
Y coordinates for the oversampled data array.
Returns
-------
alpha_os : ndarray
Alpha coordinates for the data array, with shape (y_os, x_os).
wave_os : ndarray
Wavelength coordinates for the data array, with shape (y_os, x_os),
in um.
"""
os_shape = x_os.shape
alpha_os = np.full(os_shape, np.nan)
wave_os = np.full(os_shape, np.nan)
for slice_wcs in ifu_wcs:
bbox = slice_wcs.bounding_box
x_in_bounds = (x_os >= bbox[0][0]) & (x_os <= bbox[0][1])
y_in_bounds = (y_os >= bbox[1][0]) & (y_os <= bbox[1][1])
_, alpha, lam = slice_wcs.transform(
"detector",
"slicer",
x_os[x_in_bounds & y_in_bounds],
y_os[x_in_bounds & y_in_bounds],
)
alpha_os[x_in_bounds & y_in_bounds] = -alpha
# Store wavelength, convert to um
wave_os[x_in_bounds & y_in_bounds] = lam * 1e6
return alpha_os, wave_os
def _inflate_error(error_array, extname, oversample_factor):
"""
Inflate error or variance arrays to account for oversampling.
Errors are increased by a factor dependent on the oversampling ratio
in order to account for the covariance introduced by the oversampling.
The inflation factor was determined empirically for IFU data
by comparing the reported error of single-spaxel spectra and aperture-summed
spectra, following ``cube_build`` on an oversampled image.
Empirically, based on the RMS of the aperture-summed spectrum in a line-free
region of a stellar spectrum, the true SNR does not change much (< 4%) between
N=1 and N=2/3/4. In contrast, the reported SNR increases by an amount well fit
by X = 0.23N + 0.77. I.e., X=1 for N=1, and X=1.46 for N=3. This does not account
for variations in individual pixels, but to first order, inflating by this X
factor when the oversampling is performed will produce data cubes in which the
SNR is mostly preserved accurately. Per-pixel errors in the oversampled
product are not accurately reported by the inflated errors, but the oversampled
product should be considered primarily an intermediate data product; the
errors in the resampled cube are more important.
Parameters
----------
error_array : ndarray
Error or variance image to inflate. Updated in place.
extname : {"err", "var_rnoise", "var_poisson", "var_flat"}
Extension name.
oversample_factor : float
The oversampling factor used.
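
Examples
--------
Errors scale linearly with the inflation factor; variances scale
with its square:

>>> import numpy as np
>>> err = np.ones(3)
>>> _inflate_error(err, "err", 2.0)
>>> round(float(err[0]), 4)
1.23
>>> var = np.ones(3)
>>> _inflate_error(var, "var_poisson", 2.0)
>>> round(float(var[0]), 4)
1.5129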
"""
inflation_factor = 0.23 * oversample_factor + 0.77
if str(extname).lower().startswith("var"):
error_array *= inflation_factor**2
else:
error_array *= inflation_factor
def _update_wcs_nrs_ifu(wcs, map_pixels):
"""
Update a NIRSpec IFU WCS to include the oversampling transform.
Parameters
----------
wcs : `~gwcs.WCS`
The WCS object, including transforms for all slices.
May be either coordinate-based or slice-based.
map_pixels : `~astropy.modeling.models.Model`
Model that transforms from oversampled pixels to original detector
pixels, to be prepended to the WCS pipeline.
Returns
-------
wcs : `~gwcs.WCS`
The updated WCS. If the input WCS was coordinate-based,
then the new transform is prepended to the existing "coordinates"
transform. If it was slice-based, a new WCS pipeline is created
with "coordinates" as the input frame, containing the new transform.
"""
if "coordinates" in wcs.available_frames:
# coordinate-based WCS: update the existing transform with the new mapping
first_transform = wcs.pipeline[0].transform
wcs.pipeline[0].transform = map_pixels | first_transform
wcs.pipeline[0].transform.name = first_transform.name
wcs.pipeline[0].transform.inputs = first_transform.inputs
wcs.pipeline[0].transform.outputs = first_transform.outputs
# update bounding box limits in place
det2slicer_selector = wcs.pipeline[1].transform.selector
for slnum in range(30):
bb = det2slicer_selector[slnum + 1].bounding_box
bb[0], bb[1] = map_pixels.inverse(bb[0], bb[1])
else:
# slice-based WCS
map_pixels &= Identity(1)
map_pixels.name = "coord2det"
map_pixels.inputs = ("x", "y", "name")
map_pixels.outputs = ("x", "y", "name")
bbox = wcs.bounding_box
frame = gwcs.coordinate_frames.Frame2D(name="coordinates", axes_order=(0, 1))
wcs = gwcs.WCS([(frame, map_pixels), *wcs.pipeline])
# update bounding box limits
for slnum in range(30):
bb = bbox[slnum]
bb[0], bb[1], _ = map_pixels.inverse(bb[0], bb[1], slnum)
wcs.bounding_box = bbox
return wcs
def _update_wcs(wcs, map_pixels):
"""
Update a WCS to include the oversampling transform.
Appropriate to the MIRI MRS WCS or slit-like WCS objects, following ``extract_2d``.
Parameters
----------
wcs : `~gwcs.WCS`
The WCS object, including transforms for all slices.
map_pixels : `~astropy.modeling.models.Model`
Model that transforms from oversampled pixels to original detector
pixels, to be prepended to the WCS pipeline.
Returns
-------
wcs : `~gwcs.WCS`
A new WCS pipeline, with "coordinates" as the input frame, containing the
new transform.
"""
map_pixels.name = "coord2det"
map_pixels.inputs = ("x", "y")
map_pixels.outputs = ("x", "y")
frame = gwcs.coordinate_frames.Frame2D(name="coordinates", axes_order=(0, 1))
new_wcs = gwcs.WCS([(frame, map_pixels), *wcs.pipeline])
return new_wcs
def _intermediate_models(model, data_arrays):
"""
Make new datamodels for intermediate data arrays.
Parameters
----------
model : `~stdatamodels.jwst.datamodels.IFUImageModel`
The input datamodel. Metadata will be copied from it.
data_arrays : list of ndarray or None
Data arrays to save. If None, the model returned is also None.
Returns
-------
new_models : list of `~stdatamodels.jwst.datamodels.IFUImageModel` or None
A list of datamodels containing the input data arrays.
"""
new_models = []
for data in data_arrays:
if data is None:
new_model = None
else:
new_model = datamodels.IFUImageModel(data)
new_model.update(model)
new_models.append(new_model)
return new_models
def fit_and_oversample(
model,
fit_threshold=10.0,
slope_limit=0.1,
psf_optimal=False,
oversample_factor=1.0,
return_intermediate_models=False,
):
"""
Fit a trace model and optionally oversample an IFU datamodel.
Parameters
----------
model : `~stdatamodels.jwst.datamodels.IFUImageModel`
The input datamodel, updated in place.
fit_threshold : float, optional
The signal threshold sigma for attempting spline fits within a slice region.
Lower values will create spline traces for more slices. If less than or
equal to 0, all slices will be fit.
slope_limit : float, optional
The normalized slope threshold for using the spline model in oversampled
data. Lower values will use the spline model for fainter sources. If less
than or equal to 0, the spline model will always be used.
psf_optimal : bool, optional
If True, residual corrections to the spline model are not included
in the oversampled flux. This option is generally appropriate for simple
isolated point sources only. If set, ``slope_limit`` and ``fit_threshold``
values are ignored and spline fits are attempted and used for all data.
oversample_factor : float, optional
If not 1.0, then the data will be oversampled by this factor.
return_intermediate_models : bool, optional
If True, additional image models will be returned, containing the full
spline model, the spline model as used for compact sources, the residual
model, and the linearly interpolated data.
Returns
-------
model : `~stdatamodels.jwst.datamodels.IFUImageModel`
The datamodel, updated with a trace image and optionally oversampled
arrays.
full_spline_model : `~stdatamodels.jwst.datamodels.IFUImageModel`, optional
The spline model evaluated at all pixels. Returned only if
``return_intermediate_models`` is True.
source_spline_model : `~stdatamodels.jwst.datamodels.IFUImageModel`, optional
The spline model evaluated at compact source locations only.
Returned only if ``return_intermediate_models`` is True.
linear_model : `~stdatamodels.jwst.datamodels.IFUImageModel` or None, optional
All data linearly interpolated onto the oversampled grid.
Returned only if ``return_intermediate_models`` is True. Will be None if
``oversample_factor`` is 1.0.
residual_model : `~stdatamodels.jwst.datamodels.IFUImageModel` or None, optional
Residuals from the spline fit, linearly interpolated onto the oversampled grid.
Returned only if ``return_intermediate_models`` is True. Will be None if
``oversample_factor`` is 1.0.
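
Examples
--------
A minimal usage sketch for a calibrated IFU exposure (the file name
is hypothetical)::

    from stdatamodels.jwst import datamodels

    model = datamodels.IFUImageModel("example_cal.fits")
    model = fit_and_oversample(model, oversample_factor=2.0)
    # model.data, model.dq, model.wavelength, and the error arrays
    # are now on the oversampled grid; model.trace_model holds the
    # evaluated trace image.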
"""
# Check parameters
if psf_optimal:
log.info("Ignoring fit threshold and slope limit for psf_optimal=True")
fit_threshold = 0
slope_limit = 0
# Get input data coordinates
detector = model.meta.instrument.detector
ysize, xsize = model.data.shape
if detector.startswith("NRS"):
rotate = False
if isinstance(model, datamodels.IFUImageModel):
mode = "NRS_IFU"
wcs = nrs_ifu_wcs(model)
alpha_orig = _get_alpha_nrs_ifu(wcs, xsize, ysize)
# the region map is already stored in the datamodel
region_map = model.regions
else:
raise ValueError("Unsupported mode")
elif detector.startswith("MIR"):
rotate = True
if isinstance(model, datamodels.IFUImageModel):
mode = "MIR_MRS"
wcs = model.meta.wcs
alpha_orig = _get_alpha_mir_mrs(wcs, xsize, ysize)
# Region map is stored in the transform
det2ab_transform = wcs.get_transform("detector", "alpha_beta")
region_map = det2ab_transform.label_mapper.mapper.copy()
else:
raise ValueError("Unsupported mode")
else:
raise ValueError("Unknown detector")
# Rotate input data if needed
flux_orig = model.data
if rotate:
xsize, ysize = ysize, xsize
flux_orig = np.rot90(flux_orig)
alpha_orig = np.rot90(alpha_orig)
region_map = np.rot90(region_map)
# Set thresholding for the bspline fitting
# Do some statistics on the overall cal file
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=AstropyUserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
overall_mean, _, overall_rms = scs(flux_orig[region_map > 0])
overall_mean = 0 if not np.isfinite(overall_mean) else overall_mean
overall_rms = 0 if not np.isfinite(overall_rms) else overall_rms
# Need to ensure that the median pixel value isn't negative, because that causes chaos
# Subtract off that constant
if overall_mean < 0:
flux_orig = flux_orig - overall_mean
overall_mean = 0
# Define a per-slice analysis threshold (must be brighter than some level above background)
slice_numbers = np.unique(region_map[region_map > 0])
if fit_threshold <= 0:
# In this case, all slices should be fit, so make the threshold
# lower than any real signal
signal_threshold = dict.fromkeys(slice_numbers, -np.inf)
else:
if mode == "MIR_MRS":
# For MIRI MRS we need each channel to have its own threshold, particularly
# for Ch3/Ch4 since the sky is so much brighter in Ch4
signal_threshold = dict.fromkeys(slice_numbers, np.nan)
for channel in [100, 200, 300, 400]:
ch_data = (region_map >= channel) & (region_map < channel + 100)
if not np.any(ch_data):
continue
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=AstropyUserWarning)
ch_mean, _, ch_rms = scs(flux_orig[ch_data])
for slnum in slice_numbers:
if channel <= slnum < channel + 100:
signal_threshold[slnum] = ch_mean + fit_threshold * ch_rms
else:
# For NIRSpec IFU, all regions have the same threshold
threshold = overall_mean + fit_threshold * overall_rms
signal_threshold = dict.fromkeys(slice_numbers, threshold)
# Fit spline models to all regions
fit_kwargs = _set_fit_kwargs(detector, xsize)
spline_models, spline_scales = fit_all_regions(
flux_orig, alpha_orig, region_map, signal_threshold, **fit_kwargs
)
# If oversampling is not needed, evaluate the spline models to create the
# trace image, store it in the model, and return.
# In the future, it might be useful to update the SCI extension here for the
# psf_optimal=True case, even when oversample=1, but for now, we will leave
# data unmodified.
oversample_kwargs = _set_oversample_kwargs(detector)
if oversample_factor == 1:
trace_used, full_trace = _trace_image(
flux_orig.shape,
spline_models,
spline_scales,
region_map,
alpha_orig,
slope_limit=slope_limit,
pad=oversample_kwargs["pad"],
)
if rotate:
trace_used = np.rot90(trace_used, k=-1)
full_trace = np.rot90(full_trace, k=-1)
model.trace_model = trace_used
if return_intermediate_models:
new_models = _intermediate_models(model, [full_trace, trace_used, None, None])
return model, *new_models
else:
return model
# Oversampled array size
os_shape = (int(np.ceil(ysize * oversample_factor)), xsize)
x_os = np.full(os_shape, np.nan)
y_os = np.full(os_shape, np.nan)
# Pre-compute coordinates for the new data size
log.info("Computing oversampled coordinates")
basex, basey = np.meshgrid(np.arange(xsize), np.arange(ysize))
newy, oldy = _reindex(0, ysize - 1, scale=oversample_factor)
y_os[:, :] = oldy[:, None]
x_os[:, :] = basex[oldy.astype(int), :]
if mode == "NRS_IFU":
alpha_os, wave_os = _get_oversampled_coords_nrs_ifu(wcs, x_os, y_os)
else:
# Because the MIRI data were rotated but the WCS indexes the non-rotated frame,
# the input coordinates need to be adjusted slightly
det2ab = model.meta.wcs.get_transform("detector", "alpha_beta")
alpha_os, _, wave_os = det2ab(ysize - y_os - 1, x_os)
log.info("Oversampling the flux array from the fit trace model")
flux_os, trace_used, full_trace, linear, residual = oversample_flux(
flux_orig,
alpha_orig,
region_map,
spline_models,
spline_scales,
oversample_factor,
alpha_os,
slope_limit=slope_limit,
psf_optimal=psf_optimal,
require_ngood=fit_kwargs["require_ngood"],
**oversample_kwargs,
)
log.info("Oversampling error and DQ arrays")
error_extensions = ["err", "var_rnoise", "var_poisson", "var_flat"]
errors = {}
for extname in error_extensions:
if model.hasattr(extname):
errors[extname] = getattr(model, extname)
if rotate:
errors[extname] = np.rot90(errors[extname])
dq = model.dq
if rotate:
dq = np.rot90(dq)
# Nearest pixel interpolation for the dq and regions array
closest_pix = (np.round(y_os).astype(int), np.round(x_os).astype(int))
dq_os = dq[*closest_pix]
regions_os = region_map[*closest_pix]
# Update the DQ image for pixels that used to be NaN, now replaced by spline interpolation.
# Remove the DO_NOT_USE flag, add FLUX_ESTIMATED
is_estimated = ~np.isnan(flux_os) & ((dq_os & dqflags.pixel["DO_NOT_USE"]) > 0)
dq_os[is_estimated] ^= dqflags.pixel["DO_NOT_USE"]
dq_os[is_estimated] |= dqflags.pixel["FLUX_ESTIMATED"]
# Simple linear oversample for the error arrays
errors_os = {}
for extname, error_array in errors.items():
error_os = linear_oversample(
error_array,
region_map,
oversample_factor,
fit_kwargs["require_ngood"],
edge_limit=0,
preserve_nan=False,
)
# Restore NaNs from the input, except at the estimated locations
is_nan = ~np.isfinite(error_array[closest_pix])
error_os[is_nan & ~is_estimated] = np.nan
# Inflate the errors to account for oversampling covariance
_inflate_error(error_os, extname, oversample_factor)
errors_os[extname] = error_os
# Update the wcs for new pixel scale
scale_and_shift = Scale(1 / oversample_factor) | Shift(
-(oversample_factor - 1) / (oversample_factor * 2)
)
if mode == "NRS_IFU":
map_pixels = Identity(1) & scale_and_shift
model.meta.wcs = _update_wcs_nrs_ifu(model.meta.wcs, map_pixels)
else:
# MIRI
map_pixels = scale_and_shift & Identity(1)
model.meta.wcs = _update_wcs(model.meta.wcs, map_pixels)
# If needed, undo all of our rotations before passing back the arrays
if rotate:
flux_os = np.rot90(flux_os, k=-1)
dq_os = np.rot90(dq_os, k=-1)
wave_os = np.rot90(wave_os, k=-1)
regions_os = np.rot90(regions_os, k=-1)
trace_used = np.rot90(trace_used, k=-1)
full_trace = np.rot90(full_trace, k=-1)
linear = np.rot90(linear, k=-1)
residual = np.rot90(residual, k=-1)
for extname, error_array in errors_os.items():
errors_os[extname] = np.rot90(error_array, k=-1)
# Update the model with the oversampled arrays
model.data = flux_os
model.dq = dq_os
model.wavelength = wave_os
model.trace_model = trace_used
for extname, error_array in errors_os.items():
setattr(model, extname, error_array)
if isinstance(model, datamodels.IFUImageModel):
model.regions = regions_os
# Remove some extra arrays if present: no longer needed
extras = [
"area",
"pathloss_point",
"pathloss_uniform",
"zeroframe",
]
for name in extras:
if model.hasattr(name):
delattr(model, name)
# Make sure NaNs and DO_NOT_USE flags match in all extensions
match_nans_and_flags(model)
# Return intermediate models if needed
if return_intermediate_models:
new_models = _intermediate_models(model, [full_trace, trace_used, linear, residual])
return model, *new_models
else:
return model