import logging
import math
import warnings
import numpy as np
from stdatamodels.jwst import datamodels
from stdatamodels.jwst.datamodels.dqflags import pixel
from jwst.assign_wcs.util import create_grism_bbox
from jwst.lib.reffile_utils import get_subarray_model
log = logging.getLogger(__name__)
__all__ = ["subtract_wfss_bkg", "ScalingFactorComputer"]
def subtract_wfss_bkg(
model,
bkg_filename,
wl_range_name,
mmag_extract=None,
user_mask=None,
rescaler_kwargs=None,
):
"""
Scale and subtract a background reference image from WFSS/GRISM data.
Parameters
----------
model : `~stdatamodels.jwst.datamodels.ImageModel`
Copy of input target exposure data model.
bkg_filename : str
Name of master background file for WFSS/GRISM.
wl_range_name : str
Name of wavelength range reference file.
mmag_extract : float or None, optional
Minimum AB mag of grism objects to extract. Default: None
user_mask : ndarray[bool] or None, optional
User-supplied boolean source mask: `True` for background pixels,
`False` for pixels that are part of sources. If provided, this mask
supersedes the one auto-generated from the source catalog, and
``model.meta.source_catalog`` is ignored entirely.
Default: None
rescaler_kwargs : dict or None, optional
Keyword arguments to pass to `ScalingFactorComputer`. Default: None
Returns
-------
result : `~stdatamodels.jwst.datamodels.ImageModel`
Background-subtracted target data model.
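Examples
--------
A minimal sketch of typical usage; the file names here are hypothetical
placeholders, not real input or reference files::

    from stdatamodels.jwst import datamodels
    model = datamodels.ImageModel("my_wfss_rate.fits")
    result = subtract_wfss_bkg(model, "wfss_bkg_ref.fits", "wavelengthrange_ref.asdf")
    print(result.meta.background.scaling_factor)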
"""
bkg_ref = datamodels.open(bkg_filename)
bkg_ref = get_subarray_model(model, bkg_ref)
# get the dispersion axis
try:
dispaxis = model.meta.wcsinfo.dispersion_direction
except AttributeError:
log.warning(
"Dispersion axis not found in input science image metadata. "
"Variance stopping criterion will have no effect for iterative "
"outlier rejection (will run until maxiter)."
)
dispaxis = None
if rescaler_kwargs is None:
rescaler_kwargs = {}
rescaler_kwargs["dispersion_axis"] = dispaxis
# get the source catalog for masking
if user_mask is None:
if getattr(model.meta, "source_catalog", None) is not None:
# Create a mask from the source catalog, True where there are no sources,
# i.e. in regions we can use as background.
bkg_mask = _mask_from_source_cat(model, wl_range_name, mmag_extract)
log.warning("No source_catalog found in input.meta, and custom mask not specified. ")
log.warning("No sources will be masked for background scaling.")
if not _sufficient_background_pixels(model.dq, bkg_mask, bkg_ref.data):
log.warning("Not enough background pixels to work with.")
log.warning("Step will be marked FAILED.")
# Save the mask in expected data type for the datamodel and set
# other keywords appropriately for this case
model.mask = bkg_mask.astype(np.uint32)
model.meta.background.scaling_factor = 0.0
model.meta.cal_step.bkg_subtract = "FAILED"
bkg_ref.close()
return model
else:
log.warning("No source_catalog found in input.meta. Setting all pixels as background.")
bkg_mask = np.ones(model.data.shape, dtype=bool)
else:
log.info("Using user-supplied source mask for background scaling.")
# we want a more generous criterion for sufficient background pixels here,
# since the user is explicitly specifying the mask.
# Assume bkg_ref is all good pixels, and set minimum fraction to zero
# Then this is effectively just dq array & user mask constraints
if not _sufficient_background_pixels(
model.dq, user_mask, np.ones_like(user_mask), min_pixfrac=0.0
):
log.warning("No background pixels found in user-supplied mask.")
log.warning("Step will be marked FAILED.")
# Save the mask in expected data type for the datamodel and set
# other keywords appropriately for this case
model.mask = user_mask.astype(np.uint32)
model.meta.background.scaling_factor = 0.0
model.meta.cal_step.bkg_subtract = "FAILED"
bkg_ref.close()
return model
bkg_mask = user_mask.astype(bool)
# save the mask in expected data type for the datamodel
model.mask = bkg_mask.astype(np.uint32)
# compute scaling factor for the reference background image
log.info("Starting iterative outlier rejection for background subtraction.")
rescaler = ScalingFactorComputer(**rescaler_kwargs)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=RuntimeWarning, message="All-NaN slice encountered"
)
# copy to avoid propagating NaNs from iterative clipping into final product
sci = model.data.copy()
var = model.err.copy() ** 2
bkg = bkg_ref.data.copy()
factor, _ = rescaler(sci, bkg, var, mask=~bkg_mask)
# check for bad value of factor
if not np.isfinite(factor):
log.warning(
"Could not determine a finite scaling factor between reference background and data."
" Step will be marked FAILED."
)
model.meta.background.scaling_factor = 0.0
model.meta.cal_step.bkg_subtract = "FAILED"
bkg_ref.close()
return model
# extract the derived factor and apply it to the unmasked, non-outlier-rejected data
subtract_this = factor * bkg_ref.data
model.data = model.data - subtract_this
model.dq = np.bitwise_or(model.dq, bkg_ref.dq)
model.meta.background.scaling_factor = factor
log.info(f"Average of scaled background image = {np.nanmean(subtract_this):.3e}")
log.info(f"Scaling factor = {factor:.5e}")
bkg_ref.close()
return model
class ScalingFactorComputer:
"""
Compute a scaling factor between science data and a reference background,
using iterative outlier rejection.
Parameters
----------
p : float, optional
Percentile used to clip outliers from both the low and high ends in
each iteration; default 1.0.
For example, with ``p=2.0``, the middle 96% of the data is kept.
maxiter : int, optional
Maximum number of iterations for outlier rejection. Default 5.
delta_rms_thresh : float, optional
Stopping criterion for outlier rejection; stops when the rms residuals
change by less than this fractional threshold in a single iteration.
For example, assuming ``delta_rms_thresh=0.1`` and a residual RMS of 100
in iteration 1, the iteration will stop if the RMS residual in iteration
2 is greater than 90.
Default 0.0, i.e., ignore this and only stop at ``maxiter``.
dispersion_axis : int, optional
The 1-indexed dispersion axis (1 or 2), used to compute the RMS
residual; must be set if ``delta_rms_thresh > 0``. Default None.
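Examples
--------
A minimal sketch; the parameter values are illustrative, not recommended
defaults:

>>> rescaler = ScalingFactorComputer(
...     p=0.5, maxiter=10, delta_rms_thresh=0.05, dispersion_axis=1
... )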
"""
def __init__(self, p=1.0, maxiter=5, delta_rms_thresh=0, dispersion_axis=None):
if (delta_rms_thresh > 0) and (dispersion_axis not in [1, 2]):
msg = (
f"Unrecognized dispersion axis {dispersion_axis}. "
"Dispersion axis must be specified if delta_rms_thresh "
"is used as a stopping criterion."
)
raise ValueError(msg)
self.p = p
self.maxiter = maxiter
self.delta_rms_thresh = delta_rms_thresh
self.dispersion_axis = dispersion_axis
def __call__(self, sci, bkg, var, mask=None):
"""
Compute the scaling factor via iterative outlier rejection.
Note that ``sci``, ``bkg``, and ``var`` are modified in place:
pixels rejected as outliers are set to NaN.
Parameters
----------
sci : ndarray
The science data.
bkg : ndarray
The reference background model.
var : ndarray
Total variance (error squared) of the science data.
mask : ndarray[bool], optional
Initial mask to be applied to the data, True where bad.
Typically, this would mask out the real sources in the data.
Returns
-------
float
Scaling factor that minimizes the error-weighted squared residuals
of ``sci - factor * bkg``, after iterative outlier rejection.
ndarray[bool]
Outlier mask generated by the iterative clipping procedure.
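Examples
--------
A minimal sketch with synthetic data, where the science image is exactly
twice the background model; copies are passed because the inputs are
modified in place:

>>> rng = np.random.default_rng(0)
>>> bkg = rng.uniform(1.0, 2.0, size=(10, 10))
>>> sci = 2.0 * bkg
>>> var = np.ones_like(sci)
>>> rescaler = ScalingFactorComputer(maxiter=3)
>>> factor, outliers = rescaler(sci.copy(), bkg.copy(), var.copy())
>>> round(float(factor), 6)
2.0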
"""
if mask is None:
mask = np.zeros(sci.shape, dtype="bool")
self._update_nans(sci, bkg, var, mask)
# iteratively reject more and more outliers
i = 0
last_rms_resid = np.inf
while i < self.maxiter:
# compute the factor that minimizes the residuals
factor = self.err_weighted_mean(sci, bkg, var)
sci_sub = sci - factor * bkg
# Check fractional improvement stopping criterion before incrementing.
# Note this never passes in iteration 0 because last_rms_resid is inf.
if self.delta_rms_thresh > 0:
rms_resid = self._compute_rms_residual(sci_sub)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=RuntimeWarning,
message="invalid value encountered in scalar divide",
)
fractional_diff = (last_rms_resid - rms_resid) / last_rms_resid
if fractional_diff < self.delta_rms_thresh:
msg = (
f"Stopping at iteration {i}; too little improvement "
"since last iteration (hit delta_rms_thresh)."
)
log.info(msg)
break
last_rms_resid = rms_resid
i += 1
# Reject outliers based on residual between sci and bkg.
# Updating the sci, var, and bkg nan values means that
# they are ignored by nanpercentile in the next iteration
limits = np.nanpercentile(sci_sub, (self.p, 100 - self.p))
mask |= np.logical_or(sci_sub < limits[0], sci_sub > limits[1])
self._update_nans(sci, bkg, var, mask)
if i >= self.maxiter:
log.info(f"Stopped at maxiter ({i}).")
self._iters_run_last_call = i
return self.err_weighted_mean(sci, bkg, var), mask
def err_weighted_mean(self, sci, bkg, var):
"""
Compute the error-weighted mean scaling factor between sci and bkg.
Any ``var == 0`` values, which can occur in real data, are masked out
before the computation.
Parameters
----------
sci : ndarray
The science data.
bkg : ndarray
The reference background model.
var : ndarray
Total variance (error squared) of the science data.
Returns
-------
float
Error-weighted mean scaling factor.
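Notes
-----
This is the closed-form, inverse-variance-weighted least-squares
solution for the factor minimizing ``sum((sci - factor * bkg)**2 / var)``:
setting the derivative with respect to ``factor`` to zero gives
``factor = sum(sci * bkg / var) / sum(bkg ** 2 / var)``.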
"""
mask = var == 0
self._update_nans(sci, bkg, var, mask)
return np.nansum(sci * bkg / var, dtype="f8") / np.nansum(bkg * bkg / var, dtype="f8")
def _update_nans(self, sci, bkg, var, mask):
sci[mask] = np.nan
bkg[mask] = np.nan
var[mask] = np.nan
def _compute_rms_residual(self, sci_sub):
"""
Calculate the RMS of the background-subtracted data along the dispersion axis.
The image is collapsed to a median profile along the cross-dispersion
axis, and the RMS of that profile is returned.
Note: ``meta.wcsinfo.dispersion_direction`` is 1-indexed coming out of
assign_wcs, i.e., in [1, 2].
Parameters
----------
sci_sub : ndarray
Background-subtracted science data.
Returns
-------
float
Root mean square
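Examples
--------
A minimal sketch with a constant image; ``dispersion_axis=2`` collapses
along numpy axis 1, and the RMS of a unit profile is 1:

>>> comp = ScalingFactorComputer(dispersion_axis=2)
>>> float(comp._compute_rms_residual(np.ones((4, 5))))
1.0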
"""
collapsing_axis = int(self.dispersion_axis - 1)
sci_sub_profile = np.nanmedian(sci_sub, axis=collapsing_axis)
return np.sqrt(np.nanmean(sci_sub_profile**2, dtype="f8"))
def _sufficient_background_pixels(dq_array, bkg_mask, bkg, min_pixfrac=0.05):
"""
Check whether enough good pixels are available for background use.
Check the DQ flags of pixels selected for background use: XOR the DQ
values with the DO_NOT_USE flag to flip the DO_NOT_USE bit, then count
the pixels that AND with the DO_NOT_USE flag, i.e., those that initially
did not have the DO_NOT_USE bit set.
Parameters
----------
dq_array : ndarray
Subarray input DQ array
bkg_mask : ndarray
Boolean background mask. True where background is GOOD.
bkg : ndarray
Background data array
min_pixfrac : float, optional
Minimum fraction of good pixels required for background use.
Default is 0.05 (5%).
Returns
-------
bool
True if the number of good background pixels exceeds ``min_pixfrac``
times the size of the DQ array.
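Examples
--------
A tiny illustration, relying on ``DO_NOT_USE`` being bit 0 (value 1):
three of the four pixels are usable, which clears the default 5%
threshold:

>>> dq = np.array([[0, 1], [0, 0]], dtype=np.uint32)
>>> mask = np.ones((2, 2), dtype=bool)
>>> bkg = np.ones((2, 2))
>>> bool(_sufficient_background_pixels(dq, mask, bkg))
True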
"""
good_bkg = bkg != 0
good_mask = np.logical_and(bkg_mask, good_bkg)
n_good = np.count_nonzero((dq_array[good_mask] ^ pixel["DO_NOT_USE"]) & pixel["DO_NOT_USE"])
min_pixels = int(min_pixfrac * dq_array.size)
return n_good > min_pixels
def _mask_from_source_cat(input_model, wl_range_name, mmag_extract=None):
"""
Create a mask that is False within bounding boxes of sources.
Parameters
----------
input_model : `~stdatamodels.jwst.datamodels.ImageModel`
Input target exposure data model
wl_range_name : str
Name of the wavelengthrange reference file
mmag_extract : float or None, optional
Minimum AB mag of grism objects to extract. Default: None
Returns
-------
bkg_mask : ndarray
Boolean mask: True for background, False for pixels that are
inside at least one of the source regions defined in the source
catalog.
"""
shape = input_model.data.shape
bkg_mask = np.ones(shape, dtype=bool)
reference_files = {"wavelengthrange": wl_range_name}
grism_obj_list = create_grism_bbox(input_model, reference_files, mmag_extract)
for obj in grism_obj_list:
order_bounding = obj.order_bounding
for order in order_bounding.keys():
((ymin, ymax), (xmin, xmax)) = order_bounding[order]
xmin = int(math.floor(xmin))
xmax = int(math.ceil(xmax)) + 1 # convert to slice limit
ymin = int(math.floor(ymin))
ymax = int(math.ceil(ymax)) + 1
xmin = max(xmin, 0)
xmax = min(xmax, shape[-1])
ymin = max(ymin, 0)
ymax = min(ymax, shape[-2])
bkg_mask[..., ymin:ymax, xmin:xmax] = False
return bkg_mask