Make faster Union

This commit is contained in:
Kirill Andreev 2020-06-08 00:52:30 +04:00
parent 44defb2114
commit 20de97c5c6
No known key found for this signature in database
GPG Key ID: CF7DA79DE4785A47
2 changed files with 80 additions and 36 deletions

View File

@ -36,6 +36,10 @@ default-extensions:
- UndecidableInstances
- FunctionalDependencies
- ViewPatterns
- ConstraintKinds
- TypeApplications
- AllowAmbiguousTypes
- MagicHash
ghc-options: -freverse-errors -Wall -threaded

View File

@ -5,7 +5,7 @@
module Union
( -- * Union type
Union(..)
Union
, eliminate
-- * Interface
@ -15,61 +15,101 @@ module Union
)
where
import Data.Kind
import Data.Function (on)
import GHC.Exts
import GHC.TypeLits
import Unsafe.Coerce
import Pretty
-- | The "one of" datatype.
--
-- Each @Union fs a@ is a @f a@, where @f@ is one of @fs@`.
data Union fs x where
Here :: f x -> Union (f : fs) x
There :: Union fs x -> Union (f : fs) x
data Union (fs :: [* -> *]) x where
MkUnion :: Integer -> Any fs x -> Union fs x
instance Eq (Union '[] a) where (==) = error "Union.empty"
instance Show (Union '[] a) where show = error "Union.empty"
instance Functor (Union '[]) where fmap = error "Union.empty"
instance Foldable (Union '[]) where foldMap = error "Union.empty"
instance Traversable (Union '[]) where traverse = error "Union.empty"
type family Find (f :: * -> *) fs :: Nat where
Find f (f : _) = 0
Find f (_ : fs) = 1 + Find f fs
instance (Eq (f a), Eq (Union fs a)) => Eq (Union (f : fs) a) where
a == b = case (a, b) of
(Here a', Here b') -> a' == b'
(There a', There b') -> a' == b'
_ -> False
type family Len (fs :: [* -> *]) :: Nat where
Len '[] = 0
Len (_ : fs) = 1 + Len fs
instance (Show (f a), Show (Union fs a)) => Show (Union (f : fs) a) where
show = eliminate show show
type Member f fs = KnownNat (Find f fs)
type KnownList fs = KnownNat (Len fs)
deriving stock instance (Functor f, Functor (Union fs)) => Functor (Union (f : fs))
deriving stock instance (Foldable f, Foldable (Union fs)) => Foldable (Union (f : fs))
deriving stock instance (Traversable f, Traversable (Union fs)) => Traversable (Union (f : fs))
val :: forall n. KnownNat n => Integer
val = natVal' (proxy# :: Proxy# n)
-- | A case over `Union`.
inj
:: forall f fs n x
. ( n ~ Find f fs
, KnownNat n
)
=> f x
-> Union fs x
inj fx = MkUnion (val @n) (unsafeCoerce fx)
raise
:: Union fs x
-> Union (f : fs) x
raise (MkUnion i b) = MkUnion (i + 1) (unsafeCoerce b)
proj
:: forall f fs n x
. ( n ~ Find f fs
, KnownNat n
)
=> Union fs x
-> Maybe (f x)
proj (MkUnion i body)
| i == val @n = Just $ unsafeCoerce body
| otherwise = Nothing
split
:: Union (f : fs) x
-> Either (f x) (Union fs x)
split = eliminate Left Right
vacuum :: Union '[] a -> b
vacuum = error "Empty union"
-- | A case-split over `Union`.
eliminate
:: (f x -> a)
-> (Union fs x -> a)
-> (Union (f : fs) x -> a)
eliminate here there = \case
Here fx -> here fx
There rest -> there rest
eliminate here there (MkUnion i body)
| i == 0 = here $ unsafeCoerce body
| otherwise = there $ MkUnion (i - 1) (unsafeCoerce body)
-- | The @f@ functior is in the @fs@ list.
class Member f fs where
-- | Embed @f@ into some `Union`.
inj :: f x -> Union fs x
instance Eq (Union '[] a) where (==) = vacuum
instance Show (Union '[] a) where show = vacuum
instance Functor (Union '[]) where fmap _ = vacuum
instance Foldable (Union '[]) where foldMap _ = vacuum
instance Traversable (Union '[]) where traverse _ = vacuum
-- | Check if a `Union` is actually @f@.
proj :: Union fs x -> Maybe (f x)
instance (Eq (f a), Eq (Union fs a)) => Eq (Union (f : fs) a) where
(==) = (==) `on` split
instance {-# OVERLAPS #-} Member f (f : fs) where
inj = Here
proj = eliminate Just (const Nothing)
instance (Show (f a), Show (Union fs a)) => Show (Union (f : fs) a) where
show = eliminate show show
instance Member f fs => Member f (g : fs) where
inj = There . inj
proj = eliminate (const Nothing) proj
instance (Functor f, Functor (Union fs)) => Functor (Union (f : fs)) where
fmap f = eliminate (inj . fmap f) (raise . fmap f)
instance (Foldable f, Foldable (Union fs)) => Foldable (Union (f : fs)) where
foldMap f = eliminate (foldMap f) (foldMap f)
instance (Traversable f, Traversable (Union fs)) => Traversable (Union (f : fs)) where
traverse f = eliminate (fmap inj . traverse f) (fmap raise . traverse f)
instance Pretty1 (Union '[]) where
pp1 = error "Union.empty"
pp1 = vacuum
instance (Pretty1 f, Pretty1 (Union fs)) => Pretty1 (Union (f : fs)) where
pp1 = eliminate pp1 pp1