diff --git a/tools/lsp/squirrel/package.yaml b/tools/lsp/squirrel/package.yaml index b54b056c1..96084fd20 100644 --- a/tools/lsp/squirrel/package.yaml +++ b/tools/lsp/squirrel/package.yaml @@ -36,6 +36,10 @@ default-extensions: - UndecidableInstances - FunctionalDependencies - ViewPatterns + - ConstraintKinds + - TypeApplications + - AllowAmbiguousTypes + - MagicHash ghc-options: -freverse-errors -Wall -threaded diff --git a/tools/lsp/squirrel/src/Union.hs b/tools/lsp/squirrel/src/Union.hs index f50379a29..691ec7cc6 100644 --- a/tools/lsp/squirrel/src/Union.hs +++ b/tools/lsp/squirrel/src/Union.hs @@ -5,7 +5,7 @@ module Union ( -- * Union type - Union(..) + Union , eliminate -- * Interface @@ -15,61 +15,101 @@ module Union ) where +import Data.Kind +import Data.Function (on) + +import GHC.Exts +import GHC.TypeLits + +import Unsafe.Coerce + import Pretty -- | The "one of" datatype. -- -- Each @Union fs a@ is a @f a@, where @f@ is one of @fs@`. -data Union fs x where - Here :: f x -> Union (f : fs) x - There :: Union fs x -> Union (f : fs) x +data Union (fs :: [* -> *]) x where + MkUnion :: Integer -> Any fs x -> Union fs x -instance Eq (Union '[] a) where (==) = error "Union.empty" -instance Show (Union '[] a) where show = error "Union.empty" -instance Functor (Union '[]) where fmap = error "Union.empty" -instance Foldable (Union '[]) where foldMap = error "Union.empty" -instance Traversable (Union '[]) where traverse = error "Union.empty" +type family Find (f :: * -> *) fs :: Nat where + Find f (f : _) = 0 + Find f (_ : fs) = 1 + Find f fs -instance (Eq (f a), Eq (Union fs a)) => Eq (Union (f : fs) a) where - a == b = case (a, b) of - (Here a', Here b') -> a' == b' - (There a', There b') -> a' == b' - _ -> False +type family Len (fs :: [* -> *]) :: Nat where + Len '[] = 0 + Len (_ : fs) = 1 + Len fs -instance (Show (f a), Show (Union fs a)) => Show (Union (f : fs) a) where - show = eliminate show show +type Member f fs = KnownNat (Find f fs) +type KnownList fs = KnownNat (Len fs) -deriving stock instance (Functor f, Functor (Union fs)) => Functor (Union (f : fs)) -deriving stock instance (Foldable f, Foldable (Union fs)) => Foldable (Union (f : fs)) -deriving stock instance (Traversable f, Traversable (Union fs)) => Traversable (Union (f : fs)) +val :: forall n. KnownNat n => Integer +val = natVal' (proxy# :: Proxy# n) --- | A case over `Union`. +inj + :: forall f fs n x + . ( n ~ Find f fs + , KnownNat n + ) + => f x + -> Union fs x +inj fx = MkUnion (val @n) (unsafeCoerce fx) + +raise + :: Union fs x + -> Union (f : fs) x +raise (MkUnion i b) = MkUnion (i + 1) (unsafeCoerce b) + +proj + :: forall f fs n x + . ( n ~ Find f fs + , KnownNat n + ) + => Union fs x + -> Maybe (f x) +proj (MkUnion i body) + | i == val @n = Just $ unsafeCoerce body + | otherwise = Nothing + +split + :: Union (f : fs) x + -> Either (f x) (Union fs x) +split = eliminate Left Right + +vacuum :: Union '[] a -> b +vacuum = error "Empty union" + +-- | A case-split over `Union`. eliminate :: (f x -> a) -> (Union fs x -> a) -> (Union (f : fs) x -> a) -eliminate here there = \case - Here fx -> here fx - There rest -> there rest +eliminate here there (MkUnion i body) + | i == 0 = here $ unsafeCoerce body + | otherwise = there $ MkUnion (i - 1) (unsafeCoerce body) --- | The @f@ functior is in the @fs@ list. -class Member f fs where - -- | Embed @f@ into some `Union`. - inj :: f x -> Union fs x +instance Eq (Union '[] a) where (==) = vacuum +instance Show (Union '[] a) where show = vacuum +instance Functor (Union '[]) where fmap _ = vacuum +instance Foldable (Union '[]) where foldMap _ = vacuum +instance Traversable (Union '[]) where traverse _ = vacuum - -- | Check if a `Union` is actually @f@. - proj :: Union fs x -> Maybe (f x) +instance (Eq (f a), Eq (Union fs a)) => Eq (Union (f : fs) a) where + (==) = (==) `on` split -instance {-# OVERLAPS #-} Member f (f : fs) where - inj = Here - proj = eliminate Just (const Nothing) +instance (Show (f a), Show (Union fs a)) => Show (Union (f : fs) a) where + show = eliminate show show -instance Member f fs => Member f (g : fs) where - inj = There . inj - proj = eliminate (const Nothing) proj +instance (Functor f, Functor (Union fs)) => Functor (Union (f : fs)) where + fmap f = eliminate (inj . fmap f) (raise . fmap f) + +instance (Foldable f, Foldable (Union fs)) => Foldable (Union (f : fs)) where + foldMap f = eliminate (foldMap f) (foldMap f) + +instance (Traversable f, Traversable (Union fs)) => Traversable (Union (f : fs)) where + traverse f = eliminate (fmap inj . traverse f) (fmap raise . traverse f) instance Pretty1 (Union '[]) where - pp1 = error "Union.empty" + pp1 = vacuum instance (Pretty1 f, Pretty1 (Union fs)) => Pretty1 (Union (f : fs)) where pp1 = eliminate pp1 pp1 \ No newline at end of file