Source code for jwst.assign_mtwcs.moving_target_wcs

"""Functions for adjusting the WCS of a moving target exposure."""

import logging
from copy import deepcopy

import numpy as np
from astropy.modeling.models import Identity, Shift
from gwcs import WCS
from gwcs import coordinate_frames as cf
from stdatamodels.jwst import datamodels

from jwst.assign_wcs.util import update_s_region_imaging
from jwst.datamodels import ModelLibrary
from jwst.lib.exposure_types import IMAGING_TYPES
from jwst.stpipe.utilities import record_step_status

# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)

# Public names exported by a star-import of this module.
__all__ = ["assign_moving_target_wcs", "add_mt_frame"]


def assign_moving_target_wcs(input_models):
    """
    Adjust the WCS of a moving target exposure.

    Computes the average RA and DEC of a moving target in all science
    exposures in an association and adds a step to each of the WCS pipelines
    to allow aligning the exposures to the average location of the target.

    If any science exposure fails the moving-target metadata check, a warning
    is logged, the ``assign_mtwcs`` step status is recorded as not done on the
    library, and the library is returned without any WCS being modified.

    Parameters
    ----------
    input_models : `~jwst.datamodels.library.ModelLibrary`
        A collection of data models.

    Returns
    -------
    `~jwst.datamodels.library.ModelLibrary`
        The modified data models.

    Raises
    ------
    TypeError
        If ``input_models`` is not a `~jwst.datamodels.library.ModelLibrary`.
    """
    if not isinstance(input_models, ModelLibrary):
        raise TypeError(f"Expected a ModelLibrary object, not {type(input_models)}")

    science_indices = input_models.indices_for_exptype("science")
    # One slot per science exposure; NaN marks "not yet filled".
    mt_ra = np.full(len(science_indices), np.nan)
    mt_dec = np.full(len(science_indices), np.nan)
    mt_valid = True
    with input_models:
        # Loop over only the science exposures in the ModelLibrary.
        # ``j`` is the position in mt_ra/mt_dec, ``i`` is the library index;
        # enumerate avoids a linear ``list.index`` search on every iteration.
        for j, i in enumerate(science_indices):
            # Only the metadata is read in this first pass.
            meta = input_models.read_metadata(i, flatten=False)
            mt_valid = _is_mt_meta_valid(meta)
            if not mt_valid:
                break
            mt_ra[j] = meta["meta"]["wcsinfo"]["mt_ra"]
            mt_dec[j] = meta["meta"]["wcsinfo"]["mt_dec"]

    if not mt_valid:
        log.warning("One or more MT RA/Dec values missing in input images.")
        log.warning("Step will be skipped, resulting in target misalignment.")
        record_step_status(input_models, "assign_mtwcs", False)
        return input_models

    # Compute the mean MT RA/Dec over all science exposures
    mt_avra = mt_ra.mean()
    mt_avdec = mt_dec.mean()

    with input_models:
        for i in science_indices:
            model = input_models.borrow(i)
            model.meta.wcsinfo.mt_avra = mt_avra
            model.meta.wcsinfo.mt_avdec = mt_avdec
            if isinstance(model, datamodels.MultiSlitModel):
                # Each slit carries its own WCS and its own MT position.
                for ind, slit in enumerate(model.slits):
                    new_wcs = add_mt_frame(
                        slit.meta.wcs,
                        mt_avra,
                        mt_avdec,
                        slit.meta.wcsinfo.mt_ra,
                        slit.meta.wcsinfo.mt_dec,
                    )
                    # Drop the old WCS before assigning the replacement.
                    del model.slits[ind].meta.wcs
                    model.slits[ind].meta.wcs = new_wcs
            else:
                new_wcs = add_mt_frame(
                    model.meta.wcs,
                    mt_avra,
                    mt_avdec,
                    model.meta.wcsinfo.mt_ra,
                    model.meta.wcsinfo.mt_dec,
                )
                # Drop the old WCS before assigning the replacement.
                del model.meta.wcs
                model.meta.wcs = new_wcs
                if model.meta.exposure.type.lower() in IMAGING_TYPES:
                    update_s_region_imaging(model)
            record_step_status(model, "assign_mtwcs", True)
            # Every borrowed model is shelved back, flagged as modified.
            input_models.shelve(model, i, modify=True)
    return input_models
def add_mt_frame(wcs, ra_average, dec_average, mt_ra, mt_dec):
    """
    Add a "moving_target" frame to the WCS pipeline.

    The last step of the input pipeline is replaced by two steps: the
    original output frame paired with a shift by
    ``(ra_average - mt_ra, dec_average - mt_dec)``, followed by a copy of
    that output frame renamed ``"moving_target"``.

    Parameters
    ----------
    wcs : `~gwcs.wcs.WCS`
        WCS object for the observation or slit.
    ra_average : float
        The average RA of all observations.
    dec_average : float
        The average DEC of all observations.
    mt_ra, mt_dec : float
        The RA, DEC of the moving target in the observation.

    Returns
    -------
    new_wcs : `~gwcs.wcs.WCS`
        The WCS for the moving target observation.

    Raises
    ------
    TypeError
        If the output frame of ``wcs`` is neither a ``CelestialFrame`` nor a
        ``CompositeFrame``.
    """
    # Every step except the final (output frame) one is reused as-is.
    steps = wcs.pipeline[:-1]

    # The new frame is a renamed copy of the current output frame.
    mt_frame = deepcopy(wcs.output_frame)
    mt_frame.name = "moving_target"

    ra_offset = ra_average - mt_ra
    dec_offset = dec_average - mt_dec

    if not isinstance(mt_frame, (cf.CelestialFrame, cf.CompositeFrame)):
        raise TypeError(f"Unrecognized coordinate frame type {type(mt_frame)}.")

    sky_shift = Shift(ra_offset) & Shift(dec_offset)
    if isinstance(mt_frame, cf.CelestialFrame):
        to_moving_target = sky_shift
    else:
        # Composite frame: the third coordinate is passed through untouched.
        to_moving_target = sky_shift & Identity(1)

    steps.extend(
        [
            (wcs.output_frame, to_moving_target),
            (mt_frame, None),
        ]
    )
    return WCS(steps)
def _is_mt_meta_valid(meta): """ Check if the metadata contains valid moving target RA/DEC. Checks both the top-level wcsinfo as well as the wcsinfo for all the slits in a MultiSlitModel. Parameters ---------- meta : dict Nested metadata dictionary from a data model, as output from ``read_metadata`` with ``flatten=False``. Returns ------- bool True if valid, False otherwise. """ # check all the slits for slit in meta.get("slits", []): if ( slit["meta"]["wcsinfo"].get("mt_ra", None) is None or slit["meta"]["wcsinfo"].get("mt_dec", None) is None ): return False # check the top-level wcsinfo if ( meta["meta"]["wcsinfo"].get("mt_ra", None) is None or meta["meta"]["wcsinfo"].get("mt_dec", None) is None ): return False return True