import warnings

import matplotlib.pyplot as plt
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.visualization.wcsaxes.frame import EllipticalFrame

from regions import (CircleSphericalSkyRegion,
                     RangeSphericalSkyRegion,
                     make_example_dataset)

# Two simulated all-sky datasets: the default configuration (presumably
# Galactic -- the first panel below is labelled "Galactic coordinates")
# and one whose WCS axes are switched to ICRS RA/Dec with an 'AIT'
# projection.
dataset = make_example_dataset(data='simulated')
config = {'ctype': ('RA---AIT', 'DEC--AIT')}
dataset_icrs = make_example_dataset(data='simulated', config=config)
wcs, wcs_icrs = dataset.wcs, dataset_icrs.wcs

# Both spherical regions are defined in the Galactic frame: a 30 deg
# circle centred on (l, b) = (100, -30) deg ...
sph_circ = CircleSphericalSkyRegion(
    SkyCoord(100, -30, unit=u.deg, frame='galactic'),
    30 * u.deg,
)
# ... and a longitude/latitude box whose longitude range runs from
# 315 deg to 45 deg (presumably wrapping through l = 0).
sph_range = RangeSphericalSkyRegion(
    frame='galactic',
    latitude_range=[0, 45] * u.deg,
    longitude_range=[315, 45] * u.deg,
)

# ICRS copies of both regions, drawn in the second (ICRS) panel.
sph_circ_transf, sph_range_transf = (
    region.transform_to('icrs') for region in (sph_circ, sph_range)
)

# A square 7x7 inch figure holding two stacked all-sky panels: the top one
# uses the Galactic WCS, the bottom one the ICRS WCS. Both use an
# elliptical frame to outline the full-sky projection.
fig = plt.figure(figsize=(7, 7))

axes = [
    fig.add_axes(rect, projection=projection, frame_class=EllipticalFrame)
    for rect, projection in (
        ([0.15, 0.575, 0.8, 0.425], wcs),       # top panel
        ([0.15, 0.05, 0.8, 0.425], wcs_icrs),   # bottom panel
    )
]

# Grid line style per celestial frame. Each frame is drawn identically
# whether it is the axes' own grid or the overlay grid (Galactic: solid
# black, ICRS: dotted gray), so the two panels read consistently.
GRID_STYLES = {
    'galactic': {'color': 'black', 'ls': 'solid'},
    'icrs': {'color': 'gray', 'ls': 'dotted'},
}


def plot_panel(ax, panel_wcs, circle, range_region, *, native_frame,
               overlay_frame, xlabel, ylabel, title,
               xlim=(20, 340), ylim=(10, 170)):
    """Draw one all-sky panel: grids, labels, title and both regions.

    Parameters
    ----------
    ax : WCSAxes
        Axes created with ``projection=panel_wcs``.
    panel_wcs : WCS
        WCS of ``ax``; the spherical regions are converted to pixel
        regions with it.
    circle, range_region : spherical sky regions
        Circle and range regions to draw, in blue and red respectively.
    native_frame : str
        Key into ``GRID_STYLES`` used for the axes' own coordinate grid.
    overlay_frame : str
        Frame passed to ``ax.get_coords_overlay``; also the key into
        ``GRID_STYLES`` for the overlay grid.
    xlabel, ylabel, title : str
        Axis labels and panel title.
    xlim, ylim : tuple of float
        View limits in pixel coordinates of ``panel_wcs``.
    """
    ax.coords.grid(**GRID_STYLES[native_frame])
    ax.set_xlabel(xlabel, labelpad=10)
    ax.set_ylabel(ylabel)
    ax.set_title(title, pad=5)

    overlay = ax.get_coords_overlay(overlay_frame)
    overlay.grid(**GRID_STYLES[overlay_frame])

    # Densely sampled boundaries so the projected outlines can follow the
    # all-sky projection's distortions (more points for the circle).
    circle_artist = circle.to_pixel(
        wcs=panel_wcs,
        include_boundary_distortions=True,
        n_points=1000,
    ).plot(ax=ax, color='tab:blue', lw=3)
    range_artist = range_region.to_pixel(
        wcs=panel_wcs,
        include_boundary_distortions=True,
        n_points=250,
    ).plot(ax=ax, color='tab:red', lw=3)

    # Hide any part of either region that falls outside the elliptical sky
    # boundary; previously only the circle was clipped.
    for artist in (circle_artist, range_artist):
        artist.set_clip_path(ax.coords.frame.patch)

    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)


plot_panel(axes[0], wcs, sph_circ, sph_range,
           native_frame='galactic', overlay_frame='icrs',
           xlabel=r'Galactic $\ell$', ylabel=r'Galactic $b$',
           title='Galactic coordinates')

plot_panel(axes[1], wcs_icrs, sph_circ_transf, sph_range_transf,
           native_frame='icrs', overlay_frame='galactic',
           xlabel='RA', ylabel='Dec',
           title='ICRS coordinates')
# Label RA ticks in degrees instead of astropy's default hour-angle units.
axes[1].coords[0].set_format_unit(u.deg)