Make faster Union

This commit is contained in:
Kirill Andreev 2020-06-08 00:52:30 +04:00
parent 44defb2114
commit 20de97c5c6
No known key found for this signature in database
GPG Key ID: CF7DA79DE4785A47
2 changed files with 80 additions and 36 deletions

View File

@ -36,6 +36,10 @@ default-extensions:
- UndecidableInstances - UndecidableInstances
- FunctionalDependencies - FunctionalDependencies
- ViewPatterns - ViewPatterns
- ConstraintKinds
- TypeApplications
- AllowAmbiguousTypes
- MagicHash
ghc-options: -freverse-errors -Wall -threaded ghc-options: -freverse-errors -Wall -threaded

View File

@ -5,7 +5,7 @@
module Union module Union
( -- * Union type ( -- * Union type
Union(..) Union
, eliminate , eliminate
-- * Interface -- * Interface
@ -15,61 +15,101 @@ module Union
) )
where where
import Data.Kind
import Data.Function (on)
import GHC.Exts
import GHC.TypeLits
import Unsafe.Coerce
import Pretty import Pretty
-- | The "one of" datatype. -- | The "one of" datatype.
-- --
-- Each @Union fs a@ is a @f a@, where @f@ is one of @fs@`. -- Each @Union fs a@ is a @f a@, where @f@ is one of @fs@`.
data Union fs x where data Union (fs :: [* -> *]) x where
Here :: f x -> Union (f : fs) x MkUnion :: Integer -> Any fs x -> Union fs x
There :: Union fs x -> Union (f : fs) x
instance Eq (Union '[] a) where (==) = error "Union.empty" type family Find (f :: * -> *) fs :: Nat where
instance Show (Union '[] a) where show = error "Union.empty" Find f (f : _) = 0
instance Functor (Union '[]) where fmap = error "Union.empty" Find f (_ : fs) = 1 + Find f fs
instance Foldable (Union '[]) where foldMap = error "Union.empty"
instance Traversable (Union '[]) where traverse = error "Union.empty"
instance (Eq (f a), Eq (Union fs a)) => Eq (Union (f : fs) a) where type family Len (fs :: [* -> *]) :: Nat where
a == b = case (a, b) of Len '[] = 0
(Here a', Here b') -> a' == b' Len (_ : fs) = 1 + Len fs
(There a', There b') -> a' == b'
_ -> False
instance (Show (f a), Show (Union fs a)) => Show (Union (f : fs) a) where type Member f fs = KnownNat (Find f fs)
show = eliminate show show type KnownList fs = KnownNat (Len fs)
deriving stock instance (Functor f, Functor (Union fs)) => Functor (Union (f : fs)) val :: forall n. KnownNat n => Integer
deriving stock instance (Foldable f, Foldable (Union fs)) => Foldable (Union (f : fs)) val = natVal' (proxy# :: Proxy# n)
deriving stock instance (Traversable f, Traversable (Union fs)) => Traversable (Union (f : fs))
-- | A case over `Union`. inj
:: forall f fs n x
. ( n ~ Find f fs
, KnownNat n
)
=> f x
-> Union fs x
inj fx = MkUnion (val @n) (unsafeCoerce fx)
raise
:: Union fs x
-> Union (f : fs) x
raise (MkUnion i b) = MkUnion (i + 1) (unsafeCoerce b)
proj
:: forall f fs n x
. ( n ~ Find f fs
, KnownNat n
)
=> Union fs x
-> Maybe (f x)
proj (MkUnion i body)
| i == val @n = Just $ unsafeCoerce body
| otherwise = Nothing
split
:: Union (f : fs) x
-> Either (f x) (Union fs x)
split = eliminate Left Right
vacuum :: Union '[] a -> b
vacuum = error "Empty union"
-- | A case-split over `Union`.
eliminate eliminate
:: (f x -> a) :: (f x -> a)
-> (Union fs x -> a) -> (Union fs x -> a)
-> (Union (f : fs) x -> a) -> (Union (f : fs) x -> a)
eliminate here there = \case eliminate here there (MkUnion i body)
Here fx -> here fx | i == 0 = here $ unsafeCoerce body
There rest -> there rest | otherwise = there $ MkUnion (i - 1) (unsafeCoerce body)
-- | The @f@ functior is in the @fs@ list. instance Eq (Union '[] a) where (==) = vacuum
class Member f fs where instance Show (Union '[] a) where show = vacuum
-- | Embed @f@ into some `Union`. instance Functor (Union '[]) where fmap _ = vacuum
inj :: f x -> Union fs x instance Foldable (Union '[]) where foldMap _ = vacuum
instance Traversable (Union '[]) where traverse _ = vacuum
-- | Check if a `Union` is actually @f@. instance (Eq (f a), Eq (Union fs a)) => Eq (Union (f : fs) a) where
proj :: Union fs x -> Maybe (f x) (==) = (==) `on` split
instance {-# OVERLAPS #-} Member f (f : fs) where instance (Show (f a), Show (Union fs a)) => Show (Union (f : fs) a) where
inj = Here show = eliminate show show
proj = eliminate Just (const Nothing)
instance Member f fs => Member f (g : fs) where instance (Functor f, Functor (Union fs)) => Functor (Union (f : fs)) where
inj = There . inj fmap f = eliminate (inj . fmap f) (raise . fmap f)
proj = eliminate (const Nothing) proj
instance (Foldable f, Foldable (Union fs)) => Foldable (Union (f : fs)) where
foldMap f = eliminate (foldMap f) (foldMap f)
instance (Traversable f, Traversable (Union fs)) => Traversable (Union (f : fs)) where
traverse f = eliminate (fmap inj . traverse f) (fmap raise . traverse f)
instance Pretty1 (Union '[]) where instance Pretty1 (Union '[]) where
pp1 = error "Union.empty" pp1 = vacuum
instance (Pretty1 f, Pretty1 (Union fs)) => Pretty1 (Union (f : fs)) where instance (Pretty1 f, Pretty1 (Union fs)) => Pretty1 (Union (f : fs)) where
pp1 = eliminate pp1 pp1 pp1 = eliminate pp1 pp1