# Planar
import matplotlib.pyplot as plt
import numpy as np
from astropy.coordinates import Angle, SkyCoord
from astropy.wcs import WCS
from regions import CircleSkyRegion

# Build a 2-axis WCS in Galactic coordinates with an Aitoff projection.
# The reference pixel (180, 90) maps to (l, b) = (0, 0); the first axis has
# a negative step (cdelt = -1), so it runs in the reversed direction.
wcs = WCS(naxis=2)
wcs.wcs.ctype = ('GLON-AIT', 'GLAT-AIT')
wcs.wcs.crval = (0, 0)
wcs.wcs.crpix = (180, 90)
wcs.wcs.cdelt = (-1, 1)
# Image shape as (rows, columns); only used for the axis limits below.
shape = (180, 360)

# Two circular sky regions with the same 30 deg radius but different centers.
radius = Angle('30 deg')
circle1 = CircleSkyRegion(
    center=SkyCoord(20, 0, unit='deg', frame='galactic'), radius=radius)
circle2 = CircleSkyRegion(
    center=SkyCoord(50, 45, unit='deg', frame='galactic'), radius=radius)

# Regular grid of test positions: every 10 deg in longitude (-180..180) and
# latitude (-90..90), flattened to an (N, 2) array of (lon, lat) pairs with
# longitude varying slowest.
lon = np.arange(-180, 181, 10)
lat = np.arange(-90, 91, 10)
coords = np.stack(np.meshgrid(lon, lat, indexing='ij'), axis=-1).reshape(-1, 2)
skycoords = SkyCoord(coords, unit='deg', frame='galactic')

# Combine the circles with the region operators: ``&`` keeps positions that
# fall in both circles, ``^`` keeps positions that fall in exactly one.
compound_and = circle1 & circle2
compound_xor = circle1 ^ circle2

# The WCS is handed to ``contains`` together with the sky positions; the
# resulting boolean masks select the matching grid points.
mask_and = compound_and.contains(skycoords, wcs)
mask_xor = compound_xor.contains(skycoords, wcs)
skycoords_and = skycoords[mask_and]
skycoords_xor = skycoords[mask_xor]

# Plot on axes that use the WCS above as their projection.
fig = plt.figure(figsize=(7, 3.5))
ax = fig.add_axes([0.15, 0.1, 0.8, 0.8], projection=wcs, aspect='equal')

# Draw the grid points in order: all of them first, then the XOR and AND
# selections on top so the highlighted subsets stay visible.
point_layers = (
    (skycoords, {'label': 'all'}),
    (skycoords_xor, {'color': 'orange', 'label': 'xor'}),
    (skycoords_and, {'color': 'magenta', 'label': 'and'}),
)
for points, style in point_layers:
    ax.scatter(points.l.value, points.b.value,
               transform=ax.get_transform('galactic'), **style)

# Outline each circle after converting it to a pixel region with the same WCS.
for region, edge in ((circle1, 'green'), (circle2, 'red')):
    outline = region.to_pixel(wcs=wcs)
    outline.plot(ax=ax, edgecolor=edge, facecolor='none', alpha=0.8, lw=3)

ax.legend(loc='lower right')

# Limit the view to the pixel extent described by ``shape``; the half-pixel
# offsets put the limits on the outer pixel edges.
n_rows, n_cols = shape
ax.set_xlim(-0.5, n_cols - 0.5)
ax.set_ylim(-0.5, n_rows - 0.5)
ax.set_title('Planar SkyRegions')