url = https://byorgey.wordpress.com
blog :: brent -> [string]
For a current project, I needed to implement type inference for a Hindley-Milner-based type system. (More about that project in an upcoming post!) If you don’t know, Hindley-Milner is what you get when you add polymorphism to the simply typed lambda calculus, but only allow forall to show up at the very outermost layer of a type. This is the fundamental basis for many real-world type systems (e.g. OCaml, or Haskell without RankNTypes enabled).
One of the core operations in any Hindley-Milner type inference algorithm is unification, where we take two types that might contain unification variables (think “named holes”) and try to make them equal, which might fail, or might provide more information about the values that the unification variables should take. For example, if we try to unify a -> Int and Char -> b, we will learn that a = Char and b = Int; on the other hand, trying to unify a -> Int and (b, Char) will fail, because there is no way to make those two types equal (the first is a function type whereas the second is a pair).
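To make the example concrete, here is a minimal sketch of first-order unification over a toy type. Everything in it (the Ty type, Subst, apply, occurs, unify) is made up for illustration and has nothing to do with the unification-fd API; it only assumes the containers package.

```haskell
import qualified Data.Map as M

-- Toy monotypes; TVar plays the role of a unification variable ("named hole").
data Ty = TVar String | TInt | TChar | TFun Ty Ty | TPair Ty Ty
  deriving (Eq, Show)

-- What we have learned so far: variable name |-> type.
type Subst = M.Map String Ty

-- Apply a substitution, chasing variables that are already bound.
apply :: Subst -> Ty -> Ty
apply s t@(TVar v)  = maybe t (apply s) (M.lookup v s)
apply s (TFun a b)  = TFun (apply s a) (apply s b)
apply s (TPair a b) = TPair (apply s a) (apply s b)
apply _ t           = t

-- Occurs check: does variable v appear inside the type?
occurs :: String -> Ty -> Bool
occurs v (TVar w)    = v == w
occurs v (TFun a b)  = occurs v a || occurs v b
occurs v (TPair a b) = occurs v a || occurs v b
occurs _ _           = False

-- Try to make two types equal, extending the substitution or failing.
unify :: Ty -> Ty -> Subst -> Maybe Subst
unify t1 t2 s = go (apply s t1) (apply s t2)
  where
    go (TVar v) t              = bind v t
    go t (TVar v)              = bind v t
    go TInt TInt               = Just s
    go TChar TChar             = Just s
    go (TFun a b) (TFun c d)   = unify a c s >>= unify b d
    go (TPair a b) (TPair c d) = unify a c s >>= unify b d
    go _ _                     = Nothing   -- structural mismatch
    bind v t
      | t == TVar v = Just s               -- a variable unified with itself
      | occurs v t  = Nothing              -- would be an infinite type
      | otherwise   = Just (M.insert v t s)

main :: IO ()
main = do
  -- a -> Int  vs  Char -> b
  print (unify (TFun (TVar "a") TInt) (TFun TChar (TVar "b")) M.empty)
  -- a -> Int  vs  (b, Char)
  print (unify (TFun (TVar "a") TInt) (TPair (TVar "b") TChar) M.empty)
```

The first call yields a substitution binding a to TChar and b to TInt; the second yields Nothing, since a function type can never match a pair.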
I’ve implemented this from scratch before, and it was a great learning experience, but I wasn’t looking forward to implementing it again. But then I remembered the unification-fd library and wondered whether I could use it to simplify the implementation. Long story short, although the documentation for unification-fd claims it can be used to implement Hindley-Milner, I couldn’t find any examples online, and apparently neither could anyone else. So I set out to make my own example, and you’re reading it. It turns out that unification-fd is incredibly powerful, but using it can be a bit finicky, so I hope this example can be helpful to others who wish to use it. Along the way, resources I found especially helpful include this basic unification-fd tutorial by the author, Wren Romano, as well as a blog post by Roman Cheplyaka, and the unification-fd Haddock documentation itself. I also referred to the Wikipedia page on Hindley-Milner, which is extremely thorough.
This blog post is rendered automatically from a literate Haskell file; you can find the complete working source code and blog post on GitHub. I’m always happy to receive comments, fixes, or suggestions for improvement.
Prelude: A Bunch of Extensions and Imports
We will make GHC and other people’s libraries work very hard for us. You know the drill.
> {-# LANGUAGE DeriveAnyClass #-}
> {-# LANGUAGE DeriveFoldable #-}
> {-# LANGUAGE DeriveFunctor #-}
> {-# LANGUAGE DeriveGeneric #-}
> {-# LANGUAGE DeriveTraversable #-}
> {-# LANGUAGE FlexibleContexts #-}
> {-# LANGUAGE FlexibleInstances #-}
> {-# LANGUAGE GADTs #-}
> {-# LANGUAGE LambdaCase #-}
> {-# LANGUAGE MultiParamTypeClasses #-}
> {-# LANGUAGE PatternSynonyms #-}
> {-# LANGUAGE StandaloneDeriving #-}
> {-# LANGUAGE UndecidableInstances #-}
>
> import Control.Category ((>>>))
> import Control.Monad.Except
> import Control.Monad.Reader
> import Data.Foldable (fold)
> import Data.Functor.Identity
> import Data.List (intercalate)
> import Data.Map (Map)
> import qualified Data.Map as M
> import Data.Maybe
> import Data.Set (Set, (\\))
> import qualified Data.Set as S
> import Prelude hiding (lookup)
> import Text.Printf
>
> import Text.Parsec
> import Text.Parsec.Expr
> import Text.Parsec.Language (emptyDef)
> import Text.Parsec.String
> import qualified Text.Parsec.Token as L
>
> import Control.Unification hiding ((=:=), applyBindings)
> import qualified Control.Unification as U
> import Control.Unification.IntVar
> import Data.Functor.Fixedpoint
>
> import GHC.Generics (Generic1)
>
> import System.Console.Repline
Representing our types
We’ll be implementing a language with lambdas, application, and let-expressions, as well as natural numbers with an addition operation, just to give us a base type and something to do with it. So we will have a natural number type and function types, along with polymorphism, i.e. type variables and forall. (Adding more features like sum and product types, additional base types, recursion, etc. is left as an exercise for the reader!)
So notionally, we want something like this:
type Var = String
data Type = TyVar Var | TyNat | TyFun Type Type
However, when using unification-fd, we have to encode our Type data type (i.e. the thing we want to do unification on) using a “two-level type” (see Tim Sheard’s original paper).
> type Var = String
> data TypeF a = TyVarF Var | TyNatF | TyFunF a a
>   deriving (Show, Eq, Functor, Foldable, Traversable, Generic1, Unifiable)
>
> type Type = Fix TypeF
TypeF is a “structure functor” that just defines a single level of structure; notice TypeF is not recursive at all, but uses the type parameter a to mark the places where a recursive instance would usually be. unification-fd provides a Fix type to “tie the knot” and make it recursive. (I’m not sure why unification-fd defines its own Fix type instead of using the one from Data.Fix, but perhaps the reason is that it was written before Data.Fix existed; unification-fd was first published in 2007!)
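If two-level types are new to you, the following sketch may help. It defines its own local Fix and cata as stand-ins (so it does not depend on unification-fd at all) and folds over a type one layer at a time; the render function is purely illustrative.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- A local stand-in for the library's Fix, so the sketch is self-contained.
newtype Fix f = Fix (f (Fix f))

-- One non-recursive layer of type structure; 'a' marks the recursive spots.
data TypeF a = TyVarF String | TyNatF | TyFunF a a
  deriving (Functor)

type Type = Fix TypeF

-- A generic fold: recursively collapse the children, then the current layer.
cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg (Fix t) = alg (fmap (cata alg) t)

-- Render a type; the algebra only ever sees already-rendered children.
render :: Type -> String
render = cata alg
  where
    alg (TyVarF v)   = v
    alg TyNatF       = "nat"
    alg (TyFunF a b) = "(" ++ a ++ " -> " ++ b ++ ")"

main :: IO ()
main = putStrLn (render (Fix (TyFunF (Fix (TyVarF "a")) (Fix TyNatF))))
-- prints (a -> nat)
```

The payoff of the encoding is that render contains no explicit recursion: all of it lives in cata, which works for any Functor.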
We have to derive a whole bunch of instances for TypeF which fortunately we get for free. Note in particular Generic1 and Unifiable. Unifiable is a type class from unification-fd which describes how to match up values of our type. Thanks to the work of Roman Cheplyaka, there is a default implementation for Unifiable based on a Generic1 instance (which GHC derives for us in turn), and the default implementation works great for our purposes.
unification-fd also provides a second type for tying the knot, called UTerm, defined like so:
data UTerm t v
  = UVar  !v                -- ^ A unification variable.
  | UTerm !(t (UTerm t v))  -- ^ Some structure containing subterms.
It’s similar to Fix, except it also adds unification variables of some type v. (If it means anything to you, note that UTerm is actually the free monad over t.) We also define a version of Type built with UTerm, which we will use during type inference:
> type UType = UTerm TypeF IntVar
IntVar is a type provided by unification-fd representing variables as Int values, with a mapping from variables to bindings stored in an IntMap. unification-fd also provides an STVar type which implements variables via STRefs; I presume using STVars would be faster (since no intermediate lookups in an IntMap are required) but forces us to work in the ST monad. For now I will just stick with IntVar, which makes things simpler.
At this point you might wonder: why do we have type variables in our definition of TypeF, but also use UTerm to add unification variables? Can’t we just get rid of the TyVarF constructor, and let UTerm provide the variables? Well, type variables and unification variables are subtly different, though intimately related. A type variable is actually part of a type, whereas a unification variable is not itself a type, but only stands for some type which is (as yet) unknown. After we are completely done with type inference, we won’t have a UTerm any more, but we might have a type like forall a. a -> a which still contains type variables, so we need a way to represent them. We could only get rid of the TyVarF constructor if we were doing type inference for a language without polymorphism (and yes, unification could still be helpful in such a situation; for example, to do full type reconstruction for the simply typed lambda calculus, where lambdas do not have type annotations).
Polytype represents a polymorphic type, with a forall at the front and a list of bound type variables (note that regular monomorphic types can be represented as Forall [] ty). We don’t need to make an instance of Unifiable for Polytype, since we never unify polytypes, only (mono)types. However, we can have polytypes with unification variables in them, so we need two versions, one containing a Type and one containing a UType.
> data Poly t = Forall [Var] t
>   deriving (Eq, Show, Functor)
> type Polytype  = Poly Type
> type UPolytype = Poly UType
Finally, for convenience, we can make a bunch of pattern synonyms that let us work with Type and UType just as if they were directly recursive types. This isn’t required; I just like not having to write Fix and UTerm everywhere.
> pattern TyVar :: Var -> Type
> pattern TyVar v = Fix (TyVarF v)
>
> pattern TyNat :: Type
> pattern TyNat = Fix TyNatF
>
> pattern TyFun :: Type -> Type -> Type
> pattern TyFun t1 t2 = Fix (TyFunF t1 t2)
>
> pattern UTyNat :: UType
> pattern UTyNat = UTerm TyNatF
>
> pattern UTyFun :: UType -> UType -> UType
> pattern UTyFun t1 t2 = UTerm (TyFunF t1 t2)
>
> pattern UTyVar :: Var -> UType
> pattern UTyVar v = UTerm (TyVarF v)
Expressions
Here’s a data type to represent expressions. There’s nothing much interesting to see here, since we don’t need to do anything fancy with expressions. Note that lambdas don’t have type annotations, but let-expressions can have an optional polytype annotation, which will let us talk about checking polymorphic types in addition to inferring them (a lot of presentations of Hindley-Milner don’t talk about this).
> data Expr where
>   EVar  :: Var -> Expr
>   EInt  :: Integer -> Expr
>   EPlus :: Expr -> Expr -> Expr
>   ELam  :: Var -> Expr -> Expr
>   EApp  :: Expr -> Expr -> Expr
>   ELet  :: Var -> Maybe Polytype -> Expr -> Expr -> Expr
Normally at this point we would write parsers and pretty-printers for types and expressions, but that’s boring and has very little to do with unification-fd, so I’ve left those to the end. Let’s get on with the interesting bits!
Type inference infrastructure
Before we get to the type inference algorithm proper, we’ll need to develop a bunch of infrastructure. First, here’s the concrete monad we will be using for type inference. The ReaderT Ctx will keep track of variables and their types; ExceptT TypeError of course allows us to fail with type errors; and IntBindingT is a monad transformer provided by unification-fd which supports various operations such as generating fresh variables and unifying things. Note, for reasons that will become clear later, it’s very important that the IntBindingT is on the bottom of the stack, and the ExceptT comes right above it. Beyond that we can add whatever we like on top.
> type Infer = ReaderT Ctx ( ExceptT TypeError ( IntBindingT TypeF Identity ) )
Normally, I would prefer to write everything in a “capability style” where the code is polymorphic in the monad, and just specifies what capabilities/effects it needs (either just using mtl classes directly, or using an effects library like polysemy or fused-effects), but the way the unification-fd API is designed seems to make that a bit tricky.
A type context is a mapping from variable names to polytypes; we also have a function for looking up the type of a variable in the context, and a function for running a local subcomputation in an extended context.
> type Ctx = Map Var UPolytype
>
> lookup :: Var -> Infer UType
> lookup x = do
>   ctx <- ask
>   maybe (throwError $ UnboundVar x) instantiate (M.lookup x ctx)
>
> withBinding :: MonadReader Ctx m => Var -> UPolytype -> m a -> m a
> withBinding x ty = local (M.insert x ty)
The lookup function throws an error if the variable is not in the context, and otherwise returns a UType. Conversion from the UPolytype stored in the context to a UType happens via a function called instantiate, which opens up the UPolytype and replaces each of the variables bound by the forall with a fresh unification variable. We will see the implementation of instantiate later.
We will often need to recurse over UTypes. We could just write directly recursive functions ourselves, but there is a better way. Although the unification-fd library provides a function cata for doing a fold over a term built with Fix, it doesn’t provide a counterpart for UTerm; but no matter, we can write one ourselves, like so:
> ucata :: Functor t => (v -> a) -> (t a -> a) -> UTerm t v -> a
> ucata f _ (UVar v)  = f v
> ucata f g (UTerm t) = g (fmap (ucata f g) t)
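As a quick illustration of how ucata gets used, here is a self-contained sketch that collects the unification variables of a term. It redeclares UTerm and TypeF locally (without the strictness annotations) and uses plain Int in place of IntVar, so it runs without unification-fd; the uvars function is just an example, not something from the library.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Local copies of the shapes used in this post, simplified for illustration.
data UTerm t v = UVar v | UTerm (t (UTerm t v))

data TypeF a = TyVarF String | TyNatF | TyFunF a a
  deriving (Functor)

-- Fold over a UTerm: one function for variables, one for a layer of structure.
ucata :: Functor t => (v -> a) -> (t a -> a) -> UTerm t v -> a
ucata f _ (UVar v)  = f v
ucata f g (UTerm t) = g (fmap (ucata f g) t)

-- Collect unification variables, left to right.
uvars :: UTerm TypeF Int -> [Int]
uvars = ucata (\v -> [v]) alg
  where
    alg (TyFunF l r) = l ++ r   -- children have already been folded to lists
    alg _            = []       -- type variables and nat contain no uvars

main :: IO ()
main = print (uvars (UTerm (TyFunF (UVar 0) (UTerm (TyFunF (UTerm TyNatF) (UVar 1))))))
-- prints [0,1]
```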
Now, we can write some utilities for finding free variables. Inexplicably, IntVar does not have an Ord instance (even though it is literally just a newtype over Int), so we have to derive one if we want to store them in a Set. Notice that our freeVars function finds free unification variables and free type variables; I will talk about why we need this later (this is something I got wrong at first!).
> deriving instance Ord IntVar
>
> class FreeVars a where
>   freeVars :: a -> Infer (Set (Either Var IntVar))
Finding the free variables in a UType is our first application of ucata. First, to find the free unification variables, we just use the getFreeVars function provided by unification-fd and massage the output into the right form. To find free type variables, we fold using ucata: we ignore unification variables, capture a singleton set in the TyVarF case, and in the recursive case we call fold, which will turn a TypeF (Set ...) into a Set ... using the Monoid instance for Set, i.e. union.
> instance FreeVars UType where
>   freeVars ut = do
>     fuvs <- fmap (S.fromList . map Right) . lift . lift $ getFreeVars ut
>     let ftvs = ucata (const S.empty)
>                      (\case {TyVarF x -> S.singleton (Left x); f -> fold f})
>                      ut
>     return $ fuvs `S.union` ftvs
Why don’t we just find free unification variables with ucata at the same time as the free type variables, and forget about using getFreeVars? Well, I looked at the source, and getFreeVars is actually a complicated beast. I’m really not sure what it’s doing, and I don’t trust that just manually getting the unification variables ourselves would be doing the right thing!
Now we can leverage the above instance to find free variables in UPolytypes and type contexts. For a UPolytype, we of course have to subtract off any variables bound by the forall.
> instance FreeVars UPolytype where
>   freeVars (Forall xs ut) = (\\ (S.fromList (map Left xs))) <$> freeVars ut
>
> instance FreeVars Ctx where
>   freeVars = fmap S.unions . mapM freeVars . M.elems
And here’s a simple utility function to generate fresh unification variables, built on top of the freeVar function provided by unification-fd.
> fresh :: Infer UType > fresh = UVar <$> lift ( lift freeVar )
One thing to note is the annoying calls to lift we have to do in the definition of the FreeVars instance for UType, and in the definition of fresh. The getFreeVars and freeVar functions provided by unification-fd have to run in a monad which is an instance of BindingMonad, and the BindingMonad class does not provide instances for mtl transformers. We could write our own instances so that these functions would work in our Infer monad automatically, but honestly that sounds like a lot of work. Sprinkling a few lifts here and there isn’t so bad.
Next, a data type to represent type errors, and an instance of the Fallible class, needed so that unification-fd can inject errors into our error type when it encounters unification errors. Basically we just need to provide two specific constructors to represent an “occurs check” failure (i.e. an infinite type), or a unification mismatch failure.
> data TypeError where
>   UnboundVar :: String -> TypeError
>   Infinite   :: IntVar -> UType -> TypeError
>   Mismatch   :: TypeF UType -> TypeF UType -> TypeError
>
> instance Fallible TypeF IntVar TypeError where
>   occursFailure   = Infinite
>   mismatchFailure = Mismatch
The =:= operator provided by unification-fd is how we unify two types. It has a kind of bizarre type:
(=:=) :: ( BindingMonad t v m, Fallible t v e
         , MonadTrans em, Functor (em m), MonadError e (em m) )
      => UTerm t v -> UTerm t v -> em m (UTerm t v)
(Apparently I am not the only one who thinks this type is bizarre; the unification-fd source code contains the comment -- TODO: what was the reason for the MonadTrans madness?)
I had to stare at this for a while to understand it. It says that the output will be in some BindingMonad (such as IntBindingT), and there must be a single error monad transformer on top of it, with an error type that implements Fallible. So =:= can return ExceptT TypeError (IntBindingT TypeF Identity) UType, but it cannot be used directly in our Infer monad, because that has a ReaderT on top of the ExceptT. So I just made my own version with an extra lift to get it to work directly in the Infer monad. While we’re at it, we’ll make a lifted version of applyBindings, which has the same issue.
> (=:=) :: UType -> UType -> Infer UType
> s =:= t = lift $ s U.=:= t
>
> applyBindings :: UType -> Infer UType
> applyBindings = lift . U.applyBindings
Converting between mono- and polytypes
Central to the way Hindley-Milner works is the way we move back and forth between polytypes and monotypes. First, let’s see how to turn UPolytypes into UTypes, hinted at earlier in the definition of the lookup function. To instantiate a UPolytype, we generate a fresh unification variable for each variable bound by the Forall, and then substitute them throughout the type.
> instantiate :: UPolytype -> Infer UType
> instantiate (Forall xs uty) = do
>   xs' <- mapM (const fresh) xs
>   return $ substU (M.fromList (zip (map Left xs) xs')) uty
The substU function can substitute for either kind of variable in a UType (right now we only need it to substitute for type variables, but we will need it to substitute for unification variables later). Of course, it is implemented via ucata. In the variable cases we make sure to leave the variable alone if it is not a key in the given substitution mapping. In the recursive non-variable case, we just roll up the TypeF UType into a UType by applying UTerm. This is the power of ucata: we can deal with all the boring recursive cases in one fell swoop. This function won’t have to change if we add new types to the language in the future.
> substU :: Map (Either Var IntVar) UType -> UType -> UType
> substU m = ucata
>   (\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
>   (\case
>       TyVarF v -> fromMaybe (UTyVar v) (M.lookup (Left v) m)
>       f -> UTerm f
>   )
There is one other way to convert a UPolytype to a UType, which happens when we want to check that an expression has a polymorphic type specified by the user. For example, let foo : forall a. a -> a = \x.3 in ... should be a type error, because the user specified that foo should have type forall a. a -> a, but then gave the implementation \x.3 which is too specific. In this situation we can’t just instantiate the polytype: that would create a unification variable for a, and while typechecking \x.3 it would unify a with nat. But in this case we don’t want a to unify with nat. It has to be held entirely abstract, because the user’s claim is that this function will work for any type a.
Instead of generating unification variables, we want to generate what are known as Skolem variables. Skolem variables do not unify with anything other than themselves. So how can we get unification-fd to do that? It does not have any built-in notion of Skolem variables. What we can do instead is to just embed the variables within the UType as UTyVars instead of UVars! unification-fd does not even know those are variables; it just sees them as another rigid part of the structure that must be matched exactly, just as a TyFun has to match another TyFun. The one remaining issue is that we need to generate fresh Skolem variables; it certainly would not do to have them collide with Skolem variables from some other forall. We could carry around our own supply of unique names in the Infer monad for this purpose, which would probably be the “proper” way to do things; but for now I did something more expedient: just get unification-fd to generate fresh unification variables, then rip the (unique! fresh!) Ints out of them and use those to make our Skolem variables.
> skolemize :: UPolytype -> Infer UType
> skolemize (Forall xs uty) = do
>   xs' <- mapM (const fresh) xs
>   return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty
>   where
>     toSkolem (UVar v) = UTyVar (mkVarName "s" v)
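To see why instantiating the annotation would be wrong, here is a rough, self-contained sketch with a toy unifier (no occurs check, and nothing to do with the real unification-fd machinery; Ty, Subst, walk and unify are all hypothetical). It compares the inferred type of \x.3, namely u0 -> nat, against the annotation forall a. a -> a, once with a replaced by a flexible variable and once with a replaced by a rigid Skolem.

```haskell
import qualified Data.Map as M
import Data.Maybe (isJust)

-- Skolem: rigid, only matches itself. UVar: flexible, can be bound.
data Ty = Skolem String | UVar Int | TNat | TFun Ty Ty
  deriving (Eq, Show)

type Subst = M.Map Int Ty

-- Follow bindings of a flexible variable to whatever it currently stands for.
walk :: Subst -> Ty -> Ty
walk s (UVar n) | Just t <- M.lookup n s = walk s t
walk _ t = t

-- A deliberately tiny unifier: no occurs check, just enough for the demo.
unify :: Subst -> Ty -> Ty -> Maybe Subst
unify s a b = case (walk s a, walk s b) of
  (UVar m, UVar n) | m == n     -> Just s
  (UVar m, t)                   -> Just (M.insert m t s)
  (t, UVar n)                   -> Just (M.insert n t s)
  (TNat, TNat)                  -> Just s
  (Skolem x, Skolem y) | x == y -> Just s
  (TFun a1 b1, TFun a2 b2)      -> unify s a1 a2 >>= \s' -> unify s' b1 b2
  _                             -> Nothing

-- The inferred type of \x.3 is u0 -> nat.
inferred :: Ty
inferred = TFun (UVar 0) TNat

main :: IO ()
main = do
  -- Annotation instantiated: a becomes flexible u1, which happily unifies with nat.
  print (isJust (unify M.empty inferred (TFun (UVar 1) (UVar 1))))
  -- Annotation skolemized: a becomes rigid s0, which refuses to unify with nat.
  print (isJust (unify M.empty inferred (TFun (Skolem "s0") (Skolem "s0"))))
```

The first check succeeds (wrongly accepting the definition), while the second fails, which is the behaviour we want from check.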
When unification-fd generates fresh IntVars it seems that it starts at minBound :: Int and increments, so we can add maxBound + 1 to get numbers that look reasonable. Again, this is not very principled, but for this toy example, who cares?
> mkVarName :: String -> IntVar -> Var
> mkVarName nm (IntVar v) = nm ++ show (v + (maxBound :: Int) + 1)
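The arithmetic is easy to check in isolation. In this small sketch, offset is a hypothetical helper that mirrors the expression inside mkVarName; it assumes (as observed above, not guaranteed by the library) that the first fresh variable is minBound.

```haskell
-- Mirrors the arithmetic in mkVarName: shift so that minBound maps to 0.
offset :: Int -> Int
offset v = v + (maxBound :: Int) + 1

main :: IO ()
main = do
  print (offset minBound)        -- 0, since minBound + maxBound == -1
  print (offset (minBound + 1))  -- 1
  print (offset (minBound + 2))  -- 2
```

Because the addition associates to the left, minBound + maxBound is computed first (giving -1), so nothing overflows for variables near minBound.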
Next, how do we convert from a UType back to a UPolytype? This happens when we have inferred the type of a let-bound variable and go to put it in the context; typically, Hindley-Milner systems generalize the inferred type to a polytype. If a unification variable still occurs free in a type, it means it was not constrained at all, so we can universally quantify over it. However, we have to be careful: unification variables that occur in some type that is already in the context do not count as free, because we might later discover that they need to be constrained.
Also, just before we do the generalization, it’s very important that we use applyBindings. unification-fd has been collecting a substitution from unification variables to types, but for efficiency’s sake it does not actually apply the substitution until we ask it to, by calling applyBindings. Any unification variables which still remain after applyBindings really are unconstrained so far. So after applyBindings, we get the free unification variables from the type, subtract off any unification variables which are free in the context, and close over the remaining variables with a forall, substituting normal type variables for them. It does not particularly matter if these type variables are fresh (so long as they are distinct). But we can’t look only at unification variables! We have to look at free type variables too (this is the reason that our freeVars function needs to find both free type and unification variables). Why is that? Well, we might have some free type variables floating around if we previously generated some Skolem variables while checking a polymorphic type. (A term which illustrates this behavior is \y. let x : forall a. a -> a = y in x 3.) Free Skolem variables should also be generalized over.
> generalize :: UType -> Infer UPolytype
> generalize uty = do
>   uty' <- applyBindings uty
>   ctx <- ask
>   tmfvs  <- freeVars uty'
>   ctxfvs <- freeVars ctx
>   let fvs = S.toList $ tmfvs \\ ctxfvs
>       xs  = map (either id (mkVarName "a")) fvs
>   return $ Forall xs (substU (M.fromList (zip fvs (map UTyVar xs))) uty')
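Stripped of the monad, the core of generalization is just a set difference. Here is a simplified, pure sketch of that idea on a toy type; it assumes bindings have already been applied, represents the context as a plain list of types, and ignores the Skolem subtlety discussed above. All the names (Ty, Poly, uvars, close) are made up for the sketch.

```haskell
import qualified Data.Set as S

-- Toy types with named type variables and numbered unification variables.
data Ty = TVar String | UVar Int | TNat | TFun Ty Ty
  deriving (Eq, Show)

data Poly = Forall [String] Ty
  deriving (Eq, Show)

-- Free unification variables of a type.
uvars :: Ty -> S.Set Int
uvars (UVar n)   = S.singleton n
uvars (TFun a b) = uvars a `S.union` uvars b
uvars _          = S.empty

-- Replace the chosen unification variables by named type variables.
close :: S.Set Int -> Ty -> Ty
close vs (UVar n) | n `S.member` vs = TVar ("a" ++ show n)
close vs (TFun a b)                 = TFun (close vs a) (close vs b)
close _  t                          = t

-- Quantify over unification variables free in the type but not in the context.
generalize :: [Ty] -> Ty -> Poly
generalize ctxTys ty = Forall (map name (S.toList gen)) (close gen ty)
  where
    gen    = uvars ty S.\\ S.unions (map uvars ctxTys)
    name n = "a" ++ show n

main :: IO ()
main = print (generalize [UVar 0] (TFun (UVar 0) (UVar 1)))
-- prints Forall ["a1"] (TFun (UVar 0) (TVar "a1"))
```

Here u0 is mentioned by the context, so it stays a unification variable, while u1 is unconstrained and gets quantified.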
Finally, we need a way to convert Polytypes entered by the user into UPolytypes, and a way to convert the final UPolytype back into a Polytype. unification-fd provides functions unfreeze :: Fix t -> UTerm t v and freeze :: UTerm t v -> Maybe (Fix t) to convert between terms built with UTerm (with unification variables) and Fix (without unification variables). Converting to UPolytype is easy: we just use unfreeze to convert the underlying Type to a UType.
> toUPolytype :: Polytype -> UPolytype > toUPolytype = fmap unfreeze
When converting back, notice that freeze returns a Maybe; it fails if there are any unification variables remaining. So we must be careful to only use fromUPolytype when we know there are no unification variables remaining in a polytype. In fact, we will use this only at the very top level, after generalizing the type that results from inference over a top-level term. Since at the top level we only perform inference on closed terms, in an empty type context, the final generalize step will generalize over all the remaining free unification variables, since there will be no free variables in the context.
> fromUPolytype :: UPolytype -> Polytype > fromUPolytype = fmap ( fromJust . freeze )
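To get a feel for why freezing can fail, here is a small sketch of a freeze-like function written from scratch over local stand-ins for Fix and UTerm. This is only my guess at the essential behaviour, not the library’s actual implementation, and freeze' is a hypothetical name.

```haskell
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}

import Data.Maybe (isJust)

-- Local stand-ins for the library types, for illustration only.
newtype Fix f = Fix (f (Fix f))
data UTerm t v = UVar v | UTerm (t (UTerm t v))

data TypeF a = TyVarF String | TyNatF | TyFunF a a
  deriving (Functor, Foldable, Traversable)

-- Succeeds only if no unification variable occurs anywhere in the term:
-- traverse propagates a single Nothing from any subterm to the whole result.
freeze' :: Traversable t => UTerm t v -> Maybe (Fix t)
freeze' (UVar _)  = Nothing
freeze' (UTerm t) = Fix <$> traverse freeze' t

main :: IO ()
main = do
  -- a -> nat: no unification variables, so freezing works.
  print (isJust (freeze' (UTerm (TyFunF (UTerm (TyVarF "a")) (UTerm TyNatF)) :: UTerm TypeF Int)))
  -- u0 -> nat: a leftover unification variable makes freezing fail.
  print (isJust (freeze' (UTerm (TyFunF (UVar 0) (UTerm TyNatF)) :: UTerm TypeF Int)))
```

This is exactly the situation where fromJust would blow up, which is why fromUPolytype must only be used after a generalize in an empty context.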
Type inference
Finally, the type inference algorithm proper! First, to check that an expression has a given type, we infer the type of the expression and then demand (via =:=) that the inferred type must be equal to the given one. Note that =:= actually returns a UType, and it can apparently be more efficient to use the result of =:= in preference to either of the arguments to it (although they will all give equivalent results). However, in our case this doesn’t seem to make much difference.
> check :: Expr -> UType -> Infer ()
> check e ty = do
>   ty' <- infer e
>   _ <- ty =:= ty'
>   return ()
And now for the infer function. The EVar, EInt, and EPlus cases are straightforward.
> infer :: Expr -> Infer UType
> infer (EVar x) = lookup x
> infer (EInt _) = return UTyNat
> infer (EPlus e1 e2) = do
>   check e1 UTyNat
>   check e2 UTyNat
>   return UTyNat
For an application EApp e1 e2, we infer the types funTy and argTy of e1 and e2 respectively, then demand that funTy =:= UTyFun argTy resTy for a fresh unification variable resTy. Again, =:= returns a more efficient UType which is equivalent to funTy, but we don’t need to use that type directly (we return resTy instead), so we just discard the result.
> infer (EApp e1 e2) = do
>   funTy <- infer e1
>   argTy <- infer e2
>   resTy <- fresh
>   _ <- funTy =:= UTyFun argTy resTy
>   return resTy
For a lambda, we make up a fresh unification variable for the type of the argument, then infer the type of the body under an extended context. Notice how we promote the freshly generated unification variable to a UPolytype by wrapping it in Forall []; we do not generalize it, since that would turn it into forall a. a! We just want the lambda argument to have the bare unification variable as its type.
> infer (ELam x body) = do
>   argTy <- fresh
>   withBinding x (Forall [] argTy) $ do
>     resTy <- infer body
>     return $ UTyFun argTy resTy
For a let expression without a type annotation, we infer the type of the definition, then generalize it and add it to the context to infer the type of the body. It is traditional for Hindley-Milner systems to generalize let-bound things this way (although note that GHC does not generalize let-bound things with -XMonoLocalBinds enabled, which is automatically implied by -XGADTs or -XTypeFamilies).
> infer (ELet x Nothing xdef body) = do
>   ty <- infer xdef
>   pty <- generalize ty
>   withBinding x pty $ infer body
For a let expression with a type annotation, we skolemize it and check the definition with the skolemized type; the rest is the same as the previous case.
> infer (ELet x (Just pty) xdef body) = do
>   let upty = toUPolytype pty
>   upty' <- skolemize upty
>   check xdef upty'
>   withBinding x upty $ infer body
Running the Infer monad
We need a way to run computations in the Infer monad. This is a bit fiddly, and it took me a long time to put all the pieces together. (But typed holes are sooooo great! It would have taken me way longer without them…) I’ve written the definition of runInfer using the backwards function composition operator, (>>>), so that the pipeline flows from top to bottom and I can intersperse it with explanation.
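In case (>>>) is unfamiliar: on functions, f >>> g is the same as g . f, so stages are written in the order they run. A tiny standalone sketch (the pipeline function here is just a made-up example):

```haskell
import Control.Category ((>>>))

-- Stages run top to bottom: add one, then double, then render as a String.
pipeline :: Int -> String
pipeline
  =   (+ 1)
  >>> (* 2)
  >>> show

main :: IO ()
main = putStrLn (pipeline 3)  -- prints 8
```

With ordinary composition the same function would read show . (* 2) . (+ 1), i.e. bottom to top.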
> runInfer :: Infer UType -> Either TypeError Polytype
> runInfer
The first thing we do is use applyBindings to make sure that we substitute for any unification variables that we know about. This results in another Infer UType.
>   =   (>>= applyBindings)
We can now generalize over any unification variables that are left, and then convert from UPolytype to Polytype. Again, this conversion is safe because at this top level we know we will be in an empty context, so the generalization step will definitely get rid of all the remaining unification variables.
>   >>> (>>= (generalize >>> fmap fromUPolytype))
Now all that’s left is to interpret the layers of our Infer monad one by one. As promised, we start with an empty type context.
>   >>> flip runReaderT M.empty
>   >>> runExceptT
>   >>> evalIntBindingT
>   >>> runIdentity
Finally, we can make a top-level function to infer a polytype for an expression, just by composing infer and runInfer.
> inferPolytype :: Expr -> Either TypeError Polytype > inferPolytype = runInfer . infer
REPL
To be able to test things out, we can make a very simple REPL that takes input from the user and tries to parse, typecheck, and interpret it, printing either the results or an appropriate error message.
> eval :: String -> IO ()
> eval s = case parse expr "" s of
>   Left err -> print err
>   Right e -> case inferPolytype e of
>     Left tyerr -> putStrLn $ pretty tyerr
>     Right ty   -> do
>       putStrLn $ pretty e ++ " : " ++ pretty ty
>       when (ty == Forall [] TyNat) $ putStrLn $ pretty (interp e)
>
> main :: IO ()
> main = evalRepl (const (pure "HM> ")) (liftIO . eval) [] Nothing Nothing (Word (const (return []))) (return ()) (return Exit)
Here are a few examples to try out:
HM> 2 + 3
2 + 3 : nat
5
HM> \x. x
\x. x : forall a0. a0 -> a0
HM> \x.3
\x. 3 : forall a0. a0 -> nat
HM> \x. x + 1
\x. x + 1 : nat -> nat
HM> (\x. 3) (\y.y)
(\x. 3) (\y. y) : nat
3
HM> \x. y
Unbound variable y
HM> \x. x x
Infinite type u0 = u0 -> u1
HM> 3 3
Can't unify nat and nat -> u0
HM> let foo : forall a. a -> a = \x.3 in foo 5
Can't unify s0 and nat
HM> \f.\g.\x. f (g x)
\f. \g. \x. f (g x) : forall a2 a3 a4. (a3 -> a4) -> (a2 -> a3) -> a2 -> a4
HM> let f : forall a. a -> a = \x.x in let y : forall b. b -> b -> b = \z.\q. f z in y 2 3
let f : forall a. a -> a = \x. x in let y : forall b. b -> b -> b = \z. \q. f z in y 2 3 : nat
2
HM> \y. let x : forall a. a -> a = y in x 3
\y. let x : forall a. a -> a = y in x 3 : forall s1. (s1 -> s1) -> nat
HM> (\x. let y = x in y) (\z. \q. z)
(\x. let y = x in y) (\z. \q. z) : forall a1 a2. a1 -> a2 -> a1
And that’s it! Feel free to play around with this yourself, and adapt the code for your own projects if it’s helpful. And please do report any typos or bugs that you find.
Below, for completeness, you will find a simple, recursive, environment-passing interpreter, along with code for parsing and pretty-printing. I don’t give any commentary on them because, for the most part, they are straightforward and have nothing to do with unification-fd. But you are certainly welcome to look at them if you want to see how they work. The one interesting thing to say about the parser for types is that it checks that types entered by the user do not contain any free variables, and fails if they do. The parser is not really the right place to do this check, but again, it was expedient for this toy example. Also, I tend to use megaparsec for serious projects, but I had some parsec code for parsing something similar to this toy language lying around, so I just reused that.
Interpreter
> data Value where
>   VInt :: Integer -> Value
>   VClo :: Var -> Expr -> Env -> Value
>
> type Env = Map Var Value
>
> interp :: Expr -> Value
> interp = interp' M.empty
>
> interp' :: Env -> Expr -> Value
> interp' env (EVar x) = fromJust $ M.lookup x env
> interp' _   (EInt n) = VInt n
> interp' env (EPlus ea eb) =
>   case (interp' env ea, interp' env eb) of
>     (VInt va, VInt vb) -> VInt (va + vb)
>     _ -> error "Impossible! interp' EPlus on non-Ints"
> interp' env (ELam x body) = VClo x body env
> interp' env (EApp fun arg) =
>   case interp' env fun of
>     VClo x body env' ->
>       interp' (M.insert x (interp' env arg) env') body
>     _ -> error "Impossible! interp' EApp on non-closure"
> interp' env (ELet x _ xdef body) =
>   let xval = interp' env xdef
>   in  interp' (M.insert x xval env) body
Parsing
> lexer :: L.TokenParser u
> lexer = L.makeTokenParser emptyDef
>   { L.reservedNames = ["let", "in", "forall", "nat"] }
>
> parens :: Parser a -> Parser a
> parens = L.parens lexer
>
> identifier :: Parser String
> identifier = L.identifier lexer
>
> reserved :: String -> Parser ()
> reserved = L.reserved lexer
>
> reservedOp :: String -> Parser ()
> reservedOp = L.reservedOp lexer
>
> symbol :: String -> Parser String
> symbol = L.symbol lexer
>
> integer :: Parser Integer
> integer = L.natural lexer
>
> parseAtom :: Parser Expr
> parseAtom
>   =   EVar <$> identifier
>   <|> EInt <$> integer
>   <|> ELam <$> (symbol "\\" *> identifier)
>            <*> (symbol "." *> parseExpr)
>   <|> ELet <$> (reserved "let" *> identifier)
>            <*> optionMaybe (symbol ":" *> parsePolytype)
>            <*> (symbol "=" *> parseExpr)
>            <*> (reserved "in" *> parseExpr)
>   <|> parens parseExpr
>
> parseApp :: Parser Expr
> parseApp = chainl1 parseAtom (return EApp)
>
> parseExpr :: Parser Expr
> parseExpr = buildExpressionParser table parseApp
>   where
>     table = [ [ Infix (EPlus <$ reservedOp "+") AssocLeft ]
>             ]
>
> parsePolytype :: Parser Polytype
> parsePolytype = do
>   pty@(Forall xs ty) <- parsePolytype'
>   let fvs :: Set Var
>       fvs = flip cata ty $ \case
>         TyVarF x       -> S.singleton x
>         TyNatF         -> S.empty
>         TyFunF xs1 xs2 -> xs1 `S.union` xs2
>       unbound = fvs \\ S.fromList xs
>   unless (S.null unbound) $ fail $ "Unbound type variables: " ++ unwords (S.toList unbound)
>   return pty
>
> parsePolytype' :: Parser Polytype
> parsePolytype' =
>   Forall <$> (fromMaybe [] <$> optionMaybe (reserved "forall" *> many identifier <* symbol "."))
>          <*> parseType
>
> parseTypeAtom :: Parser Type
> parseTypeAtom =
>   (TyNat <$ reserved "nat") <|> (TyVar <$> identifier) <|> parens parseType
>
> parseType :: Parser Type
> parseType = buildExpressionParser table parseTypeAtom
>   where
>     table = [ [ Infix (TyFun <$ symbol "->") AssocRight ] ]
>
> expr :: Parser Expr
> expr = spaces *> parseExpr <* eof
Pretty printing
> type Prec = Int
>
> class Pretty p where
>   pretty :: p -> String
>   pretty = prettyPrec 0
>
>   prettyPrec :: Prec -> p -> String
>   prettyPrec _ = pretty
>
> instance Pretty (t (Fix t)) => Pretty (Fix t) where
>   prettyPrec p = prettyPrec p . unFix
>
> instance Pretty t => Pretty (TypeF t) where
>   prettyPrec _ (TyVarF v) = v
>   prettyPrec _ TyNatF = "nat"
>   prettyPrec p (TyFunF ty1 ty2) =
>     mparens (p > 0) $ prettyPrec 1 ty1 ++ " -> " ++ prettyPrec 0 ty2
>
> instance (Pretty (t (UTerm t v)), Pretty v) => Pretty (UTerm t v) where
>   pretty (UTerm t) = pretty t
>   pretty (UVar v)  = pretty v
>
> instance Pretty Polytype where
>   pretty (Forall [] t) = pretty t
>   pretty (Forall xs t) = unwords ("forall" : xs) ++ ". " ++ pretty t
>
> mparens :: Bool -> String -> String
> mparens True  = ("(" ++) . (++ ")")
> mparens False = id
>
> instance Pretty Expr where
>   prettyPrec _ (EVar x) = x
>   prettyPrec _ (EInt i) = show i
>   prettyPrec p (EPlus e1 e2) =
>     mparens (p > 1) $
>       prettyPrec 1 e1 ++ " + " ++ prettyPrec 2 e2
>   prettyPrec p (ELam x body) =
>     mparens (p > 0) $
>       "\\" ++ x ++ ". " ++ prettyPrec 0 body
>   prettyPrec p (ELet x mty xdef body) =
>     mparens (p > 0) $
>       "let " ++ x ++ maybe "" (\ty -> " : " ++ pretty ty) mty
>         ++ " = " ++ prettyPrec 0 xdef
>         ++ " in " ++ prettyPrec 0 body
>   prettyPrec p (EApp e1 e2) =
>     mparens (p > 2) $
>       prettyPrec 2 e1 ++ " " ++ prettyPrec 3 e2
>
> instance Pretty IntVar where
>   pretty = mkVarName "u"
>
> instance Pretty TypeError where
>   pretty (UnboundVar x)     = printf "Unbound variable %s" x
>   pretty (Infinite x ty)    = printf "Infinite type %s = %s" (pretty x) (pretty ty)
>   pretty (Mismatch ty1 ty2) = printf "Can't unify %s and %s" (pretty ty1) (pretty ty2)
>
> instance Pretty Value where
>   pretty (VInt n) = show n
>   pretty (VClo x body env)
>     = printf "<%s: %s %s>"
>       x (pretty body) (pretty env)
>
> instance Pretty Env where
>   pretty env = "[" ++ intercalate ", " bindings ++ "]"
>     where
>       bindings = map prettyBinding (M.assocs env)
>       prettyBinding (x, v) = x ++ " -> " ++ pretty v
Implementing Hindley-Milner with the unification-fd library