import matplotlib.pyplot as plt
import numpy as np
from astropy.coordinates import Angle, SkyCoord
from astropy.wcs import WCS
from regions import CircleSkyRegion, CircleSphericalSkyRegion, PixCoord

# Create a full-sky Aitoff WCS with 1-degree pixels.
# Image shape is (ny, nx) in numpy order: 180 rows of latitude by
# 360 columns of longitude.
shape = (180, 360)

wcs = WCS(naxis=2)
# CRPIX is 1-based (FITS convention), so the centre of an N-pixel axis is
# (N + 1) / 2. Using N / 2 would shift the projection centre half a pixel
# away from the image centre, so the map edges (l = +/-180, b = +/-90) would
# not line up symmetrically with the (-0.5, N - 0.5) plot limits used below.
wcs.wcs.crpix = ((shape[1] + 1) / 2, (shape[0] + 1) / 2)
# Negative CDELT1: longitude increases towards decreasing pixel x.
wcs.wcs.cdelt = (-1, 1)
wcs.wcs.crval = (0, 0)
wcs.wcs.ctype = ('GLON-AIT', 'GLAT-AIT')

# Sample the whole sky on a 10-degree galactic (l, b) grid, and compute the
# matching pixel positions in the WCS defined above.
lon = np.arange(-180, 181, 10)
lat = np.arange(-90, 91, 10)

# 'ij' indexing keeps longitude as the outer axis, so after flattening the
# points are ordered lon-major: every latitude for the first longitude, then
# every latitude for the next, and so on.
_lon_grid, _lat_grid = np.meshgrid(lon, lat, indexing='ij')
coords = np.column_stack([_lon_grid.ravel(), _lat_grid.ravel()])

skycoords = SkyCoord(coords, unit='deg', frame='galactic')
pixcoords = PixCoord.from_sky(skycoords, wcs)

# Build the same circle (center and radius) in two flavours:
# a spherical sky region and an ordinary (planar) sky region.
_circle_center = SkyCoord(50, 45, unit='deg', frame='galactic')
_circle_radius = Angle('30 deg')

sph_circle = CircleSphericalSkyRegion(center=_circle_center,
                                      radius=_circle_radius)
circle = CircleSkyRegion(center=_circle_center, radius=_circle_radius)
# Note: `circle` matches what
# sph_circle.to_sky(wcs=wcs, boundary_distortions=False) would give.

# Project both regions into pixel space. The spherical circle is projected
# with boundary distortions enabled (1000 boundary vertices); the planar
# circle uses the default conversion.
pix_circ_distort = sph_circle.to_pixel(
    wcs=wcs,
    boundary_distortions=True,
    n_vertices=1000,
)
pix_circ_nodistort = circle.to_pixel(wcs=wcs)

# Classify every grid point by which circle(s) contain it: the spherical
# circle is tested in sky coordinates, the planar one in pixel coordinates.
distort_mask = sph_circle.contains(skycoords)
nodistort_mask = pix_circ_nodistort.contains(pixcoords)

_outside_spherical = ~distort_mask
_outside_planar = ~nodistort_mask

both_skycoords = skycoords[distort_mask & nodistort_mask]
distort_only_skycoords = skycoords[distort_mask & _outside_planar]
nodistort_only_skycoords = skycoords[_outside_spherical & nodistort_mask]

# Plot the sample grid coloured by circle membership, then overlay both
# circle outlines.
fig = plt.figure(figsize=(7, 3.5))
ax = fig.add_axes([0.15, 0.1, 0.8, 0.8], projection=wcs, aspect='equal')

# Every scatter call below passes galactic (l, b) degrees, so one
# world-coordinate transform is shared by all of them.
galactic = ax.get_transform('galactic')

# Draw the full grid first so the coloured subsets are drawn on top of it.
ax.scatter(skycoords.l.value, skycoords.b.value, label='All',
           transform=galactic, color='lightgrey')
ax.scatter(distort_only_skycoords.l.value, distort_only_skycoords.b.value,
           color='magenta', label='Only within spherical circle',
           transform=galactic)
ax.scatter(nodistort_only_skycoords.l.value,
           nodistort_only_skycoords.b.value,
           color='lime', label='Only within planar circle',
           transform=galactic)
ax.scatter(both_skycoords.l.value, both_skycoords.b.value, color='orange',
           label='Within both', transform=galactic)

pix_circ_distort.plot(ax=ax, edgecolor='red', facecolor='none',
                      alpha=0.8, lw=3)

pix_circ_nodistort.plot(ax=ax, edgecolor='green', facecolor='none',
                        alpha=0.8, lw=3)

ax.legend(loc='lower right')

# Pixel limits cover the full image: pixel centres run 0..N-1, so the edges
# are at -0.5 and N - 0.5.
ax.set_xlim(-0.5, shape[1] - 0.5)
ax.set_ylim(-0.5, shape[0] - 0.5)

# Build the title from the region itself rather than hard-coding the numbers,
# so it cannot drift out of sync with the circle definition.
ax.set_title('Spherical vs. Planar circle: '
             f'Center=({sph_circle.center.l.deg:g}°, '
             f'{sph_circle.center.b.deg:g}°), '
             f'radius={sph_circle.radius.deg:g}°')