{-# LANGUAGE LambdaCase, FlexibleContexts #-}

-- | Semantic analysis and type checking for a simple functional language.
-- It provides desugaring of top-level lambdas, semantic error checking (scope,
-- arity and duplicate names), and a simplified, Hindley-Milner-inspired type
-- inference pass with support for Case, Let, and generic application.
module Semantic where

import Syntax
import qualified Data.Map as Map
import qualified Data.Set as Set
import Control.Monad.Except
import Control.Monad.State
import Control.Monad (forM, when, foldM)

-- ======================================================
-- 1) Desugaring: transform top-level lambdas into parameters
-- ======================================================

-- | Turn a top-level declaration whose body is a lambda into one with
-- explicit parameters: a parameterless @f@ bound to a lambda over @xs@
-- with body @e@ becomes @f xs = e@.
--
-- Only the outermost lambda is lifted, and only when the declaration has
-- no parameters of its own; every other declaration is returned as is.
desugarDecl :: Decl -> Decl
desugarDecl decl = case decl of
  FunDecl name [] (Lambda params body) -> FunDecl name params body
  _                                    -> decl

-- | Apply 'desugarDecl' to every top-level declaration of a program,
-- keeping the declarations in their original order.
desugarProgram :: Program -> Program
desugarProgram (Program decls) = Program (fmap desugarDecl decls)

-- ======================================================
-- 2) Errors (semantic and type errors)
-- ======================================================

-- | Errors reported by the scope, arity and duplicate-name checks.
data SemanticError
  = UndefinedVar Ident
    -- ^ A variable is used but is bound neither locally nor as a function.
  | ArityMismatch Ident Int Int
    -- ^ A function is applied to the wrong number of arguments:
    --   function name, declared arity, number of arguments supplied.
  | DuplicateFunc Ident
    -- ^ The same function name is declared more than once.
  | DuplicateParam Ident
    -- ^ A parameter name occurs more than once in one parameter list.
  | DuplicatePatternVar Ident
    -- ^ A variable is bound more than once within a single pattern.
  deriving (Show, Eq)

-- | Errors reported while inferring and checking types.
data TypeError
  = Mismatch Expr Type Type
    -- ^ An expression's type differs from the type it is compared against
    --   (the offending expression, then the two types being compared).
  | CondNotBool Expr Type
    -- ^ The condition of an @if@ is not boolean.
  | BranchesTypeDiffer Expr Expr Type Type
    -- ^ Two branches that must agree have different types
    --   (two expressions identifying the branches, then their types).
  | BinOpTypeErr BinOperator Type Type
    -- ^ A binary operator is applied to operands of unsuitable types.
  | UnOpTypeErr UnOperator Type
    -- ^ A unary operator is applied to an operand of unsuitable type.
  | UnknownVar Ident
    -- ^ A variable has no type in the typing or function environment.
  deriving Show

-- | Any error the checker can report.
data Error
  = SemErr SemanticError  -- ^ Raised by the scope/arity/duplicate checks.
  | TypErr TypeError      -- ^ Raised by type inference.
  deriving Show

-- ======================================================
-- 3) Environments and signatures
-- ======================================================

-- | Function signatures: each function name in scope mapped to its arity
-- (number of parameters).
type Sig = Map.Map Ident Int

-- | The names of the local variables currently in scope.
type Env = Set.Set Ident

-- | Types of the object language.
data Type
  = TInt
  | TFloat
  | TBool
  | TChar
  | TString
  | TList Type        -- ^ Homogeneous list with the given element type.
  | TTuple [Type]     -- ^ Tuple whose components have the given types.
  | TVar String       -- ^ Named type variable standing for a not-yet-known type.
  | TFun [Type] Type  -- ^ Function: parameter types, then the result type.
  deriving (Eq, Show)

-- | Local typing environment: the type currently assigned to each
-- variable in scope.
type TypeEnv = Map.Map Ident Type

-- | Function environment: for each function name, its parameter types and
-- its result type.
type FuncEnv = Map.Map Ident ([Type], Type)

-- | State threaded through type inference: a counter used when
-- generating fresh type variables.
data InferState = InferState { count :: Int }

-- | The inference monad: computations may fail with a 'TypeError' and
-- thread an 'InferState' for fresh-variable generation.
type Infer = ExceptT TypeError (State InferState)

-- ======================================================
-- 4) Building signatures and function environment
-- ======================================================

-- | Map every declared function name to its arity (number of parameters).
--
-- When a name is declared more than once, only one entry is kept: the
-- arity of its last declaration in the list.
buildSig :: [Decl] -> Sig
buildSig decls = Map.fromList (map arityEntry decls)
  where
    arityEntry (FunDecl name params _) = (name, length params)

-- | Initial typing assumptions for a list of functions.
--
-- Every parameter is given the same placeholder type variable @TVar "_"@
-- (the names are not fresh), and the result type is the variable named
-- @"r_" ++ name@. For repeated names the last declaration wins.
buildFuncEnv :: [Decl] -> FuncEnv
buildFuncEnv decls = Map.fromList [ assumption d | d <- decls ]
  where
    placeholder = TVar "_"
    assumption (FunDecl name params _) =
      (name, (placeholder <$ params, TVar ("r_" ++ name)))

-- ======================================================
-- 5) Semantic checking
-- ======================================================

-- | Run the scope, arity and duplicate-name checks on a whole program.
--
-- The program is desugared first, so a top-level function bound to a
-- lambda counts as a function with those parameters. The result lists
-- one 'DuplicateFunc' per top-level name declared more than once (in name
-- order), followed by the errors of each declaration in declaration order.
semanticCheck :: Program -> [Error]
semanticCheck prog =
  map SemErr (duplicateFuncs ++ concatMap (checkDecl sig) decls)
  where
    Program decls = desugarProgram prog
    sig = buildSig decls
    -- 'buildSig' keeps a single entry per name (and arities are never
    -- negative), so it cannot reveal duplicates; count how many times each
    -- name is declared instead.
    declCounts = Map.fromListWith (+) [ (declName d, 1 :: Int) | d <- decls ]
    duplicateFuncs =
      [ DuplicateFunc name | (name, n) <- Map.toList declCounts, n > 1 ]
    declName (FunDecl name _ _) = name

-- | Semantic errors of one top-level declaration: repeated parameter names
-- first, then the errors of the body, checked with only the declaration's
-- parameters in local scope.
checkDecl :: Sig -> Decl -> [SemanticError]
checkDecl sig (FunDecl _ params body) =
  repeatedParams ++ checkExpr sig (Set.fromList params) body
  where
    -- A name occurring k > 1 times is reported k times, in parameter order.
    repeatedParams = [ DuplicateParam p | p <- params, occurrences p > 1 ]
    occurrences p = length (filter (== p) params)

-- | Collect the semantic errors in an expression.
--
-- @env@ holds the locally bound variables (parameters, lambda and pattern
-- variables); @sig@ maps the functions in scope to their arity. A 'Var' is
-- defined when it appears in either. An application whose head names a
-- function from @sig@ that is not shadowed by a local variable must supply
-- exactly that many arguments. Repeated parameter names are reported once
-- per occurrence.
checkExpr :: Sig -> Env -> Expr -> [SemanticError]
checkExpr sig env expr = case expr of
  Var x
    | Set.member x env || Map.member x sig -> []
    | otherwise                            -> [UndefinedVar x]
  Lit _ -> []
  Lambda ps e ->
    duplicateParams ps ++ checkExpr sig (bindAll ps env) e
  If c t e -> concatMap recur [c, t, e]
  Case s alts -> recur s ++ concatMap (checkAlt sig env) alts
  -- Local declarations are desugared like top-level ones, so a local
  -- function bound to a lambda gets the lambda's arity instead of 0.
  -- Let-bound names shadow outer local variables of the same name, and a
  -- local function's parameters are visible only inside its own body,
  -- never in the 'in' expression.
  Let ds e ->
    let localSig = buildSig (map desugarDecl ds)
        sig'     = Map.union localSig sig
        env'     = Set.difference env (Map.keysSet localSig)
        checkLocal (FunDecl _ ps bd) =
          duplicateParams ps ++ checkExpr sig' (bindAll ps env') bd
    in concatMap checkLocal ds ++ checkExpr sig' env' e
  App{} ->
    let (fn, args) = flattenApp expr
        nArgs      = length args
        arityErrs  = case fn of
          Var f
            -- a local variable with this name shadows the signature entry
            | not (Set.member f env)
            , Just n <- Map.lookup f sig
            , n /= nArgs
            -> [ArityMismatch f n nArgs]
          _ -> []
    in recur fn ++ concatMap recur args ++ arityErrs
  BinOp _ l r -> recur l ++ recur r
  UnOp _ x    -> recur x
  List xs     -> concatMap recur xs
  Tuple xs    -> concatMap recur xs
  where
    recur = checkExpr sig env
    bindAll names scope = Set.union scope (Set.fromList names)
    duplicateParams names =
      [ DuplicateParam p | p <- names, length (filter (== p) names) > 1 ]

-- | Check one case alternative: report pattern variables bound more than
-- once (once per occurrence), then check the right-hand side with the
-- pattern's variables added to the local scope.
checkAlt :: Sig -> Env -> (Pattern, Expr) -> [SemanticError]
checkAlt sig env (pat, rhs) = repeated ++ checkExpr sig scope rhs
  where
    bound = patVars pat
    scope = env `Set.union` Set.fromList bound
    timesBound v = length (filter (== v) bound)
    repeated = [ DuplicatePatternVar v | v <- bound, timesBound v > 1 ]

-- | Split a chain of left-nested applications into its head and its
-- arguments in source order: @((f a) b) c@ becomes @(f, [a, b, c])@.
-- A non-'App' expression is its own head with no arguments.
flattenApp :: Expr -> (Expr, [Expr])
flattenApp = go []
  where
    -- Walk down the function side, consing each argument onto the
    -- accumulator. The outermost argument is met first, so the list ends
    -- up in left-to-right order without repeated appends.
    go acc (App f x) = go (x : acc) f
    go acc e         = (e, acc)

-- | Variables bound by a pattern, from left to right. Repeats are kept,
-- so a variable bound twice appears twice.
patVars :: Pattern -> [Ident]
patVars pat = case pat of
  PVar x    -> [x]
  PList ps  -> concatMap patVars ps
  PTuple ps -> concatMap patVars ps
  PWildcard -> []
  PLit _    -> []

-- ======================================================
-- 6) Type checking and inference
-- ======================================================

-- | Type-check every top-level declaration of a program.
--
-- The program is desugared first. Each declaration is inferred on its own,
-- starting from an empty local typing environment and a fresh
-- @InferState 0@; at most one 'TypErr' (the first failure) is reported per
-- declaration.
--
-- NOTE(review): this function does not itself call 'semanticCheck', so
-- scope and arity errors are not part of its result.
checkProgram :: Program -> [Error]
checkProgram prog = concatMap (runDecl fenv Map.empty) decls
  where
    Program decls = desugarProgram prog
    fenv = buildFuncEnv decls

    -- Infer a declaration's body with every parameter bound to the
    -- placeholder type variable, then reconcile the result with the
    -- declaration's return type variable.
    runDecl :: FuncEnv -> TypeEnv -> Decl -> [Error]
    runDecl funcs tenv (FunDecl f ps b) =
      case fst (runState (runExceptT action) (InferState 0)) of
        Left te -> [TypErr te]
        Right _ -> []
      where
        local  = Map.fromList [ (p, TVar "_") | p <- ps ] `Map.union` tenv
        action = inferExpr funcs local b >>= unifyReturn (TVar ("r_" ++ f))

-- | Reconcile a function's expected return type with the type inferred for
-- its body.
--
-- A type variable on either side (at the top level only) stands for an
-- unknown type and is accepted, yielding the other, more specific type.
-- This mirrors how 'inferExpr' reconciles the branches of an @if@.
-- Otherwise the two types must be equal, and the inferred type is returned.
--
-- Throws 'Mismatch' with the expected type, then the actual type. Because
-- no source expression is available here, a placeholder literal is used as
-- the expression. Nested variables (e.g. @TList (TVar _)@) are not
-- unified and must match exactly.
unifyReturn :: Type -> Type -> Infer Type
unifyReturn expected actual = case (expected, actual) of
  (TVar _, _) -> return actual
  (_, TVar _) -> return expected
  _ | expected == actual -> return actual
    | otherwise ->
        throwError (Mismatch (Lit (LString "return")) expected actual)

-- | Infer the type of an expression.
inferExpr :: FuncEnv -> TypeEnv -> Expr -> Infer Type
inferExpr :: FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv Expr
expr = case Expr
expr of
  Var Ident
x -> case Ident -> Map Ident Type -> Maybe Type
forall k a. Ord k => k -> Map k a -> Maybe a
Map.lookup Ident
x Map Ident Type
tenv of
             Just Type
t  -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
t
             Maybe Type
Nothing -> case Ident -> FuncEnv -> Maybe ([Type], Type)
forall k a. Ord k => k -> Map k a -> Maybe a
Map.lookup Ident
x FuncEnv
fenv of
                          Just ([Type]
argTys, Type
retT) -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return ([Type] -> Type -> Type
TFun [Type]
argTys Type
retT)
                          Maybe ([Type], Type)
Nothing             -> TypeError -> ExceptT TypeError (State InferState) Type
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Ident -> TypeError
UnknownVar Ident
x)
  Lit Literal
l -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> ExceptT TypeError (State InferState) Type)
-> Type -> ExceptT TypeError (State InferState) Type
forall a b. (a -> b) -> a -> b
$ Literal -> Type
literalType Literal
l

  Lambda [Ident]
ps Expr
bd -> do
    [Type]
tys <- (Ident -> ExceptT TypeError (State InferState) Type)
-> [Ident] -> ExceptT TypeError (State InferState) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (ExceptT TypeError (State InferState) Type
-> Ident -> ExceptT TypeError (State InferState) Type
forall a b. a -> b -> a
const ExceptT TypeError (State InferState) Type
freshTypeVar) [Ident]
ps
    let tenv' :: Map Ident Type
tenv' = Map Ident Type -> Map Ident Type -> Map Ident Type
forall k a. Ord k => Map k a -> Map k a -> Map k a
Map.union ([(Ident, Type)] -> Map Ident Type
forall k a. Ord k => [(k, a)] -> Map k a
Map.fromList ([Ident] -> [Type] -> [(Ident, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Ident]
ps [Type]
tys)) Map Ident Type
tenv
    Type
tr <- FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv' Expr
bd
    Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return ([Type] -> Type -> Type
TFun [Type]
tys Type
tr)

  If Expr
c Expr
t Expr
e -> do
    ()
_  <- FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv Expr
c ExceptT TypeError (State InferState) Type
-> (Type -> ExceptT TypeError (State InferState) ())
-> ExceptT TypeError (State InferState) ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Expr -> Type -> ExceptT TypeError (State InferState) ()
ensureBool Expr
c
    Type
tc <- FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv Expr
t; Type
te <- FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv Expr
e
    case (Type
tc, Type
te) of
      (Type
a,Type
b) | Type
aType -> Type -> Bool
forall a. Eq a => a -> a -> Bool
==Type
b       -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
a
      (TVar Ident
_, Type
x)        -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
x
      (Type
x, TVar Ident
_)        -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
x
      (Type, Type)
_                  -> TypeError -> ExceptT TypeError (State InferState) Type
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Expr -> Expr -> Type -> Type -> TypeError
BranchesTypeDiffer Expr
t Expr
e Type
tc Type
te)

  Case Expr
scr [(Pattern, Expr)]
alts -> do
    Type
scrT <- FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv Expr
scr
    [Type]
rs   <- [(Pattern, Expr)]
-> ((Pattern, Expr) -> ExceptT TypeError (State InferState) Type)
-> ExceptT TypeError (State InferState) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(Pattern, Expr)]
alts (((Pattern, Expr) -> ExceptT TypeError (State InferState) Type)
 -> ExceptT TypeError (State InferState) [Type])
-> ((Pattern, Expr) -> ExceptT TypeError (State InferState) Type)
-> ExceptT TypeError (State InferState) [Type]
forall a b. (a -> b) -> a -> b
$ \(Pattern
pat, Expr
bd') -> do
      ([(Ident, Type)]
vs,Type
pT) <- Pattern -> Infer ([(Ident, Type)], Type)
inferPattern Pattern
pat
      Bool
-> ExceptT TypeError (State InferState) ()
-> ExceptT TypeError (State InferState) ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Bool -> Bool
not (Type -> Bool
isPoly Type
scrT Bool -> Bool -> Bool
|| Type -> Bool
isPoly Type
pT) Bool -> Bool -> Bool
&& Type
pT Type -> Type -> Bool
forall a. Eq a => a -> a -> Bool
/= Type
scrT)
        (ExceptT TypeError (State InferState) ()
 -> ExceptT TypeError (State InferState) ())
-> ExceptT TypeError (State InferState) ()
-> ExceptT TypeError (State InferState) ()
forall a b. (a -> b) -> a -> b
$ TypeError -> ExceptT TypeError (State InferState) ()
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Expr -> Type -> Type -> TypeError
Mismatch Expr
scr Type
pT Type
scrT)
      FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv (Map Ident Type -> Map Ident Type -> Map Ident Type
forall k a. Ord k => Map k a -> Map k a -> Map k a
Map.union ([(Ident, Type)] -> Map Ident Type
forall k a. Ord k => [(k, a)] -> Map k a
Map.fromList [(Ident, Type)]
vs) Map Ident Type
tenv) Expr
bd'
    case [Type]
rs of
      (Type
r0:[Type]
rs') -> do
        Type
t <- (Type -> Type -> ExceptT TypeError (State InferState) Type)
-> Type -> [Type] -> ExceptT TypeError (State InferState) Type
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (\Type
t1 Type
t2 ->
                if Type
t1Type -> Type -> Bool
forall a. Eq a => a -> a -> Bool
==Type
t2 then Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
t1
                else if Type -> Bool
isPoly Type
t1 then Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
t2
                else if Type -> Bool
isPoly Type
t2 then Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
t1
                else TypeError -> ExceptT TypeError (State InferState) Type
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Expr -> Expr -> Type -> Type -> TypeError
BranchesTypeDiffer Expr
scr Expr
scr Type
t1 Type
t2)
              ) Type
r0 [Type]
rs'
        Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
t
      [] -> TypeError -> ExceptT TypeError (State InferState) Type
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Expr -> Type -> Type -> TypeError
Mismatch Expr
scr Type
scrT Type
scrT)

  Let [Decl]
ds Expr
e -> do
    let fenv' :: FuncEnv
fenv' = FuncEnv -> FuncEnv -> FuncEnv
forall k a. Ord k => Map k a -> Map k a -> Map k a
Map.union ([Decl] -> FuncEnv
buildFuncEnv [Decl]
ds) FuncEnv
fenv
    Map Ident Type
tenv' <- (Map Ident Type
 -> Decl -> ExceptT TypeError (State InferState) (Map Ident Type))
-> Map Ident Type
-> [Decl]
-> ExceptT TypeError (State InferState) (Map Ident Type)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (\Map Ident Type
envAcc Decl
d -> case Decl
d of
                 FunDecl Ident
fn [] Expr
bd' -> do
                   Type
t <- FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
envAcc Expr
bd'
                   Map Ident Type
-> ExceptT TypeError (State InferState) (Map Ident Type)
forall (m :: * -> *) a. Monad m => a -> m a
return (Ident -> Type -> Map Ident Type -> Map Ident Type
forall k a. Ord k => k -> a -> Map k a -> Map k a
Map.insert Ident
fn Type
t Map Ident Type
envAcc)
                 Decl
_ -> Map Ident Type
-> ExceptT TypeError (State InferState) (Map Ident Type)
forall (m :: * -> *) a. Monad m => a -> m a
return Map Ident Type
envAcc
               ) Map Ident Type
tenv [Decl]
ds
    FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv' Map Ident Type
tenv' Expr
e

  App{} -> do
    let (Expr
fn,[Expr]
args) = Expr -> (Expr, [Expr])
flattenApp Expr
expr
    Type
fty   <- FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv Expr
fn
    [Type]
argTs <- (Expr -> ExceptT TypeError (State InferState) Type)
-> [Expr] -> ExceptT TypeError (State InferState) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv) [Expr]
args
    case Type
fty of
      TFun [Type]
ps Type
r
        | [Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
ps Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= [Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
argTs
          -> TypeError -> ExceptT TypeError (State InferState) Type
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Expr -> Type -> Type -> TypeError
Mismatch Expr
expr ([Type] -> Type -> Type
TFun [Type]
ps Type
r) ([Type] -> Type -> Type
TFun [Type]
argTs Type
r))
        | [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
and ((Type -> Type -> Bool) -> [Type] -> [Type] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Type -> Type -> Bool
match [Type]
ps [Type]
argTs)
          -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
r
        | Bool
otherwise
          -> TypeError -> ExceptT TypeError (State InferState) Type
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Expr -> Type -> Type -> TypeError
Mismatch Expr
expr ([Type] -> Type -> Type
TFun [Type]
ps Type
r) ([Type] -> Type -> Type
TFun [Type]
argTs Type
r))
      TVar Ident
_ -> ExceptT TypeError (State InferState) Type
freshTypeVar
      Type
_      -> TypeError -> ExceptT TypeError (State InferState) Type
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Expr -> Type -> Type -> TypeError
Mismatch Expr
expr (Ident -> Type
TVar Ident
"_") Type
fty)

  BinOp BinOperator
op Expr
l Expr
r -> do
    Type
tl <- FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv Expr
l; Type
tr <- FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv Expr
r
    case BinOperator
op of
      BinOperator
Add -> BinOperator
-> Type -> Type -> ExceptT TypeError (State InferState) Type
numBin  BinOperator
op Type
tl Type
tr; BinOperator
Sub -> BinOperator
-> Type -> Type -> ExceptT TypeError (State InferState) Type
numBin  BinOperator
op Type
tl Type
tr
      BinOperator
Mul -> BinOperator
-> Type -> Type -> ExceptT TypeError (State InferState) Type
numBin  BinOperator
op Type
tl Type
tr; BinOperator
Div -> BinOperator
-> Type -> Type -> ExceptT TypeError (State InferState) Type
numBin  BinOperator
op Type
tl Type
tr; BinOperator
Mod -> BinOperator
-> Type -> Type -> ExceptT TypeError (State InferState) Type
numBin  BinOperator
op Type
tl Type
tr
      BinOperator
Eq  -> BinOperator
-> Type -> Type -> ExceptT TypeError (State InferState) Type
compBin BinOperator
op Type
tl Type
tr; BinOperator
Neq -> BinOperator
-> Type -> Type -> ExceptT TypeError (State InferState) Type
compBin BinOperator
op Type
tl Type
tr
      BinOperator
Lt  -> BinOperator
-> Type -> Type -> ExceptT TypeError (State InferState) Type
compBin BinOperator
op Type
tl Type
tr; BinOperator
Le  -> BinOperator
-> Type -> Type -> ExceptT TypeError (State InferState) Type
compBin BinOperator
op Type
tl Type
tr
      BinOperator
Gt  -> BinOperator
-> Type -> Type -> ExceptT TypeError (State InferState) Type
compBin BinOperator
op Type
tl Type
tr; BinOperator
Ge  -> BinOperator
-> Type -> Type -> ExceptT TypeError (State InferState) Type
compBin BinOperator
op Type
tl Type
tr

  UnOp UnOperator
op Expr
e -> do
    Type
te <- FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv Expr
e
    case UnOperator
op of
      UnOperator
Neg | Type
te Type -> [Type] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [Type
TInt,Type
TFloat] -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
te
          | Type -> Bool
isPoly Type
te              -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
TInt   -- assume int, unifica depois
          | Bool
otherwise              -> TypeError -> ExceptT TypeError (State InferState) Type
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (UnOperator -> Type -> TypeError
UnOpTypeErr UnOperator
op Type
te)
      UnOperator
Not | Type
te Type -> Type -> Bool
forall a. Eq a => a -> a -> Bool
== Type
TBool            -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
TBool
          | Type -> Bool
isPoly Type
te              -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return Type
TBool
          | Bool
otherwise              -> TypeError -> ExceptT TypeError (State InferState) Type
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (UnOperator -> Type -> TypeError
UnOpTypeErr UnOperator
op Type
te)
  List [Expr]
xs -> do
    [Type]
ts <- (Expr -> ExceptT TypeError (State InferState) Type)
-> [Expr] -> ExceptT TypeError (State InferState) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv) [Expr]
xs
    case [Type]
ts of
      []      -> ExceptT TypeError (State InferState) Type
freshTypeVar
      (Type
t:[Type]
ts') | (Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Type -> Type -> Bool
forall a. Eq a => a -> a -> Bool
==Type
t) [Type]
ts' -> Type -> ExceptT TypeError (State InferState) Type
forall (m :: * -> *) a. Monad m => a -> m a
return (Type -> Type
TList Type
t)
      [Type]
_       -> TypeError -> ExceptT TypeError (State InferState) Type
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Expr -> Type -> Type -> TypeError
Mismatch Expr
expr (Type -> Type
TList ([Type] -> Type
forall a. [a] -> a
head [Type]
ts)) (Type -> Type
TList ([Type] -> Type
forall a. [a] -> a
last [Type]
ts)))

  Tuple [Expr]
xs -> [Type] -> Type
TTuple ([Type] -> Type)
-> ExceptT TypeError (State InferState) [Type]
-> ExceptT TypeError (State InferState) Type
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Expr -> ExceptT TypeError (State InferState) Type)
-> [Expr] -> ExceptT TypeError (State InferState) [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (FuncEnv
-> Map Ident Type
-> Expr
-> ExceptT TypeError (State InferState) Type
inferExpr FuncEnv
fenv Map Ident Type
tenv) [Expr]
xs

-- | Check whether two types are compatible, treating type variables as
-- wildcards that are compatible with any type.
--
-- The comparison is structural: type variables nested inside list, tuple
-- and function types are also wildcards, so e.g. a list of @t1@ matches a
-- list of 'TInt'. Tuple and function types only match when their arities
-- agree. Any other pair of types matches only when they are equal.
--
-- No substitution is recorded; this is a compatibility test, not
-- unification.
match :: Type -> Type -> Bool
match (TVar _) _ = True
match _ (TVar _) = True
match (TList a) (TList b) = match a b
match (TTuple xs) (TTuple ys) =
  length xs == length ys && and (zipWith match xs ys)
match (TFun ps r) (TFun qs s) =
  length ps == length qs && and (zipWith match ps qs) && match r s
match a b = a == b

-- | Choose between two types already known to be compatible (see 'match').
--
-- When the first type is a type variable the second is returned. Otherwise
-- the first is kept.
resolve :: Type -> Type -> Type
resolve acc t = case acc of
  TVar _ -> t
  _      -> acc

-- | Infer the type of a pattern together with the variables it binds.
--
-- Returns @(bindings, patternType)@, where @bindings@ pairs every variable
-- bound by the pattern with its type, in left-to-right order.
--
-- * A variable pattern gets a fresh type variable.
-- * A wildcard binds nothing and gets the placeholder type variable @_@.
-- * A literal pattern has the literal's type (see 'literalType').
-- * A list pattern's element types must be pairwise compatible (per
--   'match'), otherwise a 'Mismatch' is thrown. An empty list pattern gets
--   a fresh element type variable.
-- * A tuple pattern's type is the tuple of its component types.
inferPattern :: Pattern -> Infer ([(Ident, Type)], Type)
inferPattern pat = case pat of
  PVar x -> do
    tv <- freshTypeVar
    return ([(x, tv)], tv)
  PWildcard ->
    return ([], TVar "_")
  PLit l ->
    return ([], literalType l)
  PList ps -> do
    -- bindings: variables per sub-pattern; elemTys: sub-pattern types
    (bindings, elemTys) <- unzip <$> mapM inferPattern ps
    elemTy <- case elemTys of
      []           -> freshTypeVar -- empty list pattern: element type unknown
      (ty0 : rest) -> foldM mergeElem ty0 rest
    return (concat bindings, TList elemTy)
  PTuple ps -> do
    (bindings, componentTys) <- unzip <$> mapM inferPattern ps
    return (concat bindings, TTuple componentTys)
  where
    -- Combine the element type accumulated so far with the next element's
    -- type when they are compatible (choosing via 'resolve'). Incompatible
    -- element types are reported as a mismatch between list types.
    mergeElem :: Type -> Type -> Infer Type
    mergeElem acc ty
      | match acc ty = return (resolve acc ty)
      | otherwise    =
          throwError (Mismatch (Lit (LString "pattern")) (TList acc) (TList ty))

-- | The type of a literal, determined solely by its constructor.
literalType :: Literal -> Type
literalType (LInt _)    = TInt
literalType (LFloat _)  = TFloat
literalType (LBool _)   = TBool
literalType (LChar _)   = TChar
literalType (LString _) = TString

-- | Type-check binary operators: 'numBin' for arithmetic, 'boolBin' for
-- boolean connectives and 'compBin' for comparisons. Each returns the
-- result type of the operation or throws 'BinOpTypeErr'.
numBin, boolBin, compBin :: BinOperator -> Type -> Type -> Infer Type
-- Arithmetic: two 'TInt' operands give 'TInt' and two 'TFloat' operands
-- give 'TFloat'. When either operand is a type variable the operation is
-- accepted. The result is then 'TFloat' if the other operand is known to be
-- 'TFloat', and 'TInt' otherwise, so a known Float operand is not silently
-- narrowed to Int.
-- NOTE(review): a type variable paired with a concrete non-numeric operand
-- (e.g. 'TBool') is still accepted here and typed as 'TInt'.
numBin _ TInt   TInt   = return TInt
numBin _ TFloat TFloat = return TFloat
numBin op a b
  | isPoly a || isPoly b =
      return (if TFloat `elem` [a, b] then TFloat else TInt)
  | otherwise = throwError (BinOpTypeErr op a b)

-- | Type-check boolean binary operators: two 'TBool' operands give 'TBool'.
-- If either operand is a type variable the operation is also accepted as
-- 'TBool'. Any other combination raises 'BinOpTypeErr'.
-- (The type signature is the one shared with 'numBin' and 'compBin'.)
boolBin op a b
  | bothBool || isPoly a || isPoly b = return TBool
  | otherwise                        = throwError (BinOpTypeErr op a b)
  where
    bothBool = a == TBool && b == TBool

-- | Type-check comparison operators. The result is always 'TBool'.
--
-- * Two operands of the same orderable type ('TInt', 'TFloat', 'TChar',
--   'TString') are accepted for every comparison operator.
-- * Two 'TBool' operands are accepted for equality ('Eq' / 'Neq') only.
--   Rejecting them would make e.g. @b == True@ a type error.
-- * If either operand is a type variable the comparison is accepted.
-- * Anything else raises 'BinOpTypeErr'.
-- (The type signature is the one shared with 'numBin' and 'boolBin'.)
compBin op a b
  | a == b, a `elem` orderable        = return TBool
  | a == b, a == TBool, isEquality op = return TBool
  | isPoly a || isPoly b              = return TBool
  | otherwise                         = throwError (BinOpTypeErr op a b)
  where
    orderable = [TInt, TFloat, TChar, TString]
    -- Only (in)equality is meaningful for booleans.
    isEquality Eq  = True
    isEquality Neq = True
    isEquality _   = False

-- | Require that the expression @e@ (used as a condition) has type 'TBool'.
--
-- Throws 'CondNotBool' carrying the expression and its actual type
-- otherwise.
ensureBool :: Expr -> Type -> Infer ()
ensureBool e t
  | t == TBool = return ()
  | otherwise  = throwError (CondNotBool e t)

-- | Generate a fresh type variable.
--
-- Increments the counter in the inference state and names the variable
-- after the new counter value: @t1@, @t2@, ...
freshTypeVar :: Infer Type
freshTypeVar = do
  next <- gets ((+ 1) . count)
  modify (\st -> st { count = next })
  return (TVar ('t' : show next))

-- | 'True' exactly when the type is a bare type variable ('TVar').
-- Type variables nested inside other types are not considered.
isPoly :: Type -> Bool
isPoly t = case t of
  TVar _ -> True
  _      -> False

-- ======================================================
-- 7) Export API
-- ======================================================

-- | Run the full semantic check followed by the type check on a program.
-- Returns all errors found, semantic errors first.
checkAll :: Program -> [Error]
checkAll p = concatMap ($ p) [semanticCheck, checkProgram]