import warnings

# Suppress any warning whose message contains "unrecognized position".
# The filter is installed before the matplotlib/astropy/regions imports below,
# so it is already active when they (and the plotting code) run.
# NOTE(review): presumably emitted while projecting region boundaries onto the
# all-sky WCS — confirm which library raises it.
warnings.filterwarnings('ignore', message='.*unrecognized position.*')

import matplotlib.pyplot as plt
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.visualization.wcsaxes.frame import EllipticalFrame
from astropy.wcs import WCS

from regions import (CircleSphericalSkyRegion,
                     RangeSphericalSkyRegion)

def _make_allsky_aitoff_wcs(lon_ctype, lat_ctype):
    """Return a 2-D all-sky Aitoff WCS with 1-unit pixels, centred on (0, 0).

    Both figure panels use the same pixel grid and reference point and differ
    only in the celestial frame, selected through the FITS ``CTYPE`` values.

    Parameters
    ----------
    lon_ctype : str
        ``CTYPE`` of the longitude axis, e.g. ``'GLON-AIT'`` or ``'RA---AIT'``.
    lat_ctype : str
        ``CTYPE`` of the latitude axis, e.g. ``'GLAT-AIT'`` or ``'DEC--AIT'``.

    Returns
    -------
    `~astropy.wcs.WCS`
        A new, independent WCS object.
    """
    aitoff = WCS(naxis=2)
    aitoff.wcs.crpix = (180, 90)
    # Negative longitude step: longitude increases towards the left of the
    # image, the usual orientation for sky maps.
    aitoff.wcs.cdelt = (-1, 1)
    aitoff.wcs.crval = (0, 0)
    aitoff.wcs.ctype = (lon_ctype, lat_ctype)
    return aitoff


# Full-sky Aitoff WCS objects for the Galactic and ICRS panels.
wcs = _make_allsky_aitoff_wcs('GLON-AIT', 'GLAT-AIT')
wcs_icrs = _make_allsky_aitoff_wcs('RA---AIT', 'DEC--AIT')

# Two spherical regions defined in the Galactic frame: a circle of radius
# 30 deg, and a longitude/latitude range whose longitude bounds are given
# as 315 -> 45 deg (presumably wrapping through l = 0; confirm semantics).
circle_center = SkyCoord(100, -30, unit=u.deg, frame='galactic')
sph_circ = CircleSphericalSkyRegion(circle_center, 30 * u.deg)
sph_range = RangeSphericalSkyRegion(
    frame='galactic',
    longitude_range=[315, 45] * u.deg,
    latitude_range=[0, 45] * u.deg,
)

# The same two regions re-expressed in ICRS for the lower panel
# (circle first, then range).
sph_circ_transf, sph_range_transf = (
    region.transform_to('icrs') for region in (sph_circ, sph_range)
)

fig = plt.figure()
fig.set_size_inches(7, 7)

# One (rect, projection) pair per panel: Galactic on top, ICRS below.
# EllipticalFrame gives each axes an elliptical outline, matching the
# footprint of the all-sky Aitoff projection.
panel_specs = [
    ([0.15, 0.575, 0.8, 0.425], wcs),
    ([0.15, 0.05, 0.8, 0.425], wcs_icrs),
]
axes = [
    fig.add_axes(rect, projection=proj, frame_class=EllipticalFrame)
    for rect, proj in panel_specs
]

def _draw_panel(ax, frame_wcs, circle, box, *, title, xlabel, ylabel,
                grid_style, overlay_frame, overlay_style):
    """Decorate one all-sky panel and overplot the circle and range regions.

    Both panels share the same layout: labels, grids, region styling and
    view limits. Only the frame-specific settings are passed in.

    Parameters
    ----------
    ax : `~astropy.visualization.wcsaxes.WCSAxes`
        Axes created with ``projection=frame_wcs``.
    frame_wcs : `~astropy.wcs.WCS`
        WCS used to convert the sky regions to pixel regions. It should be
        the WCS the axes were created with.
    circle : `~regions.CircleSphericalSkyRegion`
        Circle region, drawn in ``'tab:blue'``.
    box : `~regions.RangeSphericalSkyRegion`
        Longitude/latitude range region, drawn in ``'tab:red'``.
    title, xlabel, ylabel : str
        Panel annotations.
    grid_style : dict
        Keyword arguments for the native coordinate grid.
    overlay_frame : str
        Frame name passed to ``ax.get_coords_overlay``.
    overlay_style : dict
        Keyword arguments for the overlay coordinate grid.
    """
    ax.coords.grid(**grid_style)
    ax.set_xlabel(xlabel, labelpad=10)
    ax.set_ylabel(ylabel)
    ax.set_title(title, pad=5)

    overlay = ax.get_coords_overlay(overlay_frame)
    overlay.grid(**overlay_style)

    # boundary_distortions=True with n_vertices presumably samples each region
    # boundary as a polygon so it follows the curved Aitoff projection; the
    # circle is sampled more densely than the range region.
    circle_patch = circle.to_pixel(
        wcs=frame_wcs,
        boundary_distortions=True,
        n_vertices=1000,
    ).plot(ax=ax, color='tab:blue', lw=3)
    box_patch = box.to_pixel(
        wcs=frame_wcs,
        boundary_distortions=True,
        n_vertices=250,
    ).plot(ax=ax, color='tab:red', lw=3)

    # Clip both outlines to the elliptical frame so neither can be drawn
    # outside the sky ellipse (previously only the circle was clipped).
    for patch in (circle_patch, box_patch):
        patch.set_clip_path(ax.coords.frame.patch)

    # Pixel window roughly matching the extent of the all-sky Aitoff ellipse
    # for the 1 deg/pixel WCS used by both panels.
    ax.set_xlim(20, 340)
    ax.set_ylim(10, 170)


ax = axes[0]
_draw_panel(
    ax, wcs, sph_circ, sph_range,
    title='Galactic coordinates',
    xlabel=r'Galactic $\ell$',
    ylabel=r'Galactic $b$',
    grid_style={'color': 'black'},
    overlay_frame='icrs',
    overlay_style={'color': 'gray', 'ls': 'dotted'},
)

ax = axes[1]
_draw_panel(
    ax, wcs_icrs, sph_circ_transf, sph_range_transf,
    title='ICRS coordinates',
    xlabel='RA',
    ylabel='Dec',
    grid_style={'color': 'gray', 'ls': 'dotted'},
    overlay_frame='galactic',
    overlay_style={'color': 'black', 'ls': 'solid'},
)
# Label RA ticks in degrees (presumably otherwise shown in hour angle).
ax.coords[0].set_format_unit(u.deg)