Skip to content

Commit

Permalink
Printing LLVM
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Jan 6, 2025
1 parent f534427 commit 7bf7ad9
Show file tree
Hide file tree
Showing 8 changed files with 337 additions and 48 deletions.
30 changes: 30 additions & 0 deletions src/lib/LLVMFFI.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
-- Copyright 2025 Google LLC
--
-- Use of this source code is governed by a BSD-style
-- license that can be found in the LICENSE file or at
-- https://developers.google.com/open-source/licenses/bsd

module LLVMFFI (LLVMContext, initializeLLVM, compileLLVM, getFunctionPtr,
callEntryFun) where

import Data.Int
import Util (BString)

foreign import ccall "doit_cpp" doit_cpp :: Int64 -> IO Int64

type FunctionPtr = ()
type LLVMContext = ()
type DataPtr = ()
type DataListPtr = ()

initializeLLVM :: IO LLVMContext
initializeLLVM = return undefined

compileLLVM :: LLVMContext -> BString -> IO ()
compileLLVM _ _ = return undefined

getFunctionPtr :: LLVMContext -> BString -> IO FunctionPtr
getFunctionPtr _ _ = return undefined

callEntryFun :: FunctionPtr -> [DataPtr] -> IO DataPtr
callEntryFun _ _ = return undefined
12 changes: 8 additions & 4 deletions src/lib/PPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
{-# LANGUAGE NoFieldSelectors #-}

module PPrint (
Pretty (..), indent, emitLine, hcat, hlist, pprint, app,
Pretty (..), indent, emitLine, hcat, hlist, pprint, app, pprintStr,
(<+>), BSBuilder, forceOneLine) where

import Data.ByteString.Internal (w2c)
import Data.Int
import Data.Word
import Data.List (intersperse)
Expand All @@ -22,6 +23,12 @@ pprint :: Pretty a => a -> BString
pprint x = runPrinter $ prLines x
{-# SCC pprint #-}

pprintStr :: Pretty a => a -> String
pprintStr x = bs2str $ pprint x

bs2str :: BString -> String
bs2str s = map w2c $ BS.unpack s

-- === printing doc ===

type BString = BS.ByteString
Expand All @@ -32,11 +39,8 @@ data PrinterState = PrinterState {indent :: Indent, curString :: BS.Builder }
newtype PrinterM a = PrinterM { inner :: State PrinterState a }
deriving (Functor, Applicative, Monad)

-- Instances should define either `pr` (if they're expected to be one-liners
-- most of the time) or `prLines`.
class Pretty a where
pr :: a -> BSBuilder
pr x = forceOneLine $ prLines x

prLines :: a -> PrinterM ()
prLines x = emitLine $ pr x
Expand Down
137 changes: 137 additions & 0 deletions src/lib/ToLLVM.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
-- Copyright 2025 Google LLC
--
-- Use of this source code is governed by a BSD-style
-- license that can be found in the LICENSE file or at
-- https://developers.google.com/open-source/licenses/bsd

{-# LANGUAGE NoFieldSelectors #-}

module ToLLVM where

import Name
import Control.Monad
import Control.Monad.State.Strict hiding (state)
import Data.String (fromString)
import qualified Data.ByteString as BS
import qualified Types.LLVM as L
import Types.Simple
import Types.Primitives
import PPrint

import Debug.Trace
import QueryTypePure
import Util

-- === entrypoint ===

toLLVMEntryFun :: Monad m => L.Name -> TopLamExpr -> m L.Function
toLLVMEntryFun fname fun = do
finalState <- runTranslateM do
toLLVMEntryFun' fun
startNewBlock $ L.Name "__unused__"
let blocks = reverse finalState.basicBlocks
return $ L.Function fname [] blocks

-- === monad for the translation ===

data TranslateState i = TranslateState
{ basicBlocks :: [L.BasicBlock] -- reverse order
, instructions :: [L.Decl] -- reverse order
, curBlockName :: L.Name
, nameGen :: Int
, subst :: TranslateSubst i}
type TranslateSubst i = Subst (LiftE L.Operand) i VoidS

newtype TranslateM (i::S) (a:: *) =
TranslateM { inner :: State (TranslateState i) a }
deriving (Functor, Applicative, Monad)

runTranslateM :: Monad m => TranslateM VoidS a -> m (TranslateState VoidS)
runTranslateM cont = do
let initState = TranslateState [] [] (L.Name "__entry__") 0 voidSubst
return $ execState cont.inner initState

emitInstr :: L.Type -> L.Instruction -> TranslateM i L.Operand
emitInstr resultTy instr = do
v <- newLName ""
let decl = (Just v, resultTy, instr)
TranslateM $ modify \s -> s {instructions = decl : s.instructions}
return $ L.Operand (L.LocalOcc v) resultTy

emitStatement :: L.Instruction -> TranslateM i ()
emitStatement instr = do
let decl = (Nothing, L.VoidType, instr)
TranslateM $ modify \s -> s {instructions = decl : s.instructions}

extendEnv :: NameBinder i i' -> L.Operand -> TranslateM i' a -> TranslateM i a
extendEnv b x cont = TranslateM do
prevState <- get
let subst' = prevState.subst <>> (b @> LiftE x)
let (ans, newState) = runState (cont.inner) $ updateSubst prevState subst'
put $ updateSubst newState prevState.subst
return ans

lookupEnv :: Name i -> TranslateM i L.Operand
lookupEnv v = TranslateM do
env <- gets (.subst)
let LiftE x = env ! v
return x

updateSubst :: TranslateState i -> TranslateSubst i' -> TranslateState i'
updateSubst (TranslateState a b c d _) subst = TranslateState a b c d subst

newLName :: BString -> TranslateM i L.Name
newLName hint = TranslateM do
c <- gets (.nameGen)
modify \s -> s {nameGen = s.nameGen + 1}
return $ L.Name $ hint <> "_" <> fromString (show c)

startNewBlock :: L.Name -> TranslateM i ()
startNewBlock blockName = TranslateM $ modify \state -> do
let newBlock = L.BasicBlock state.curBlockName (reverse state.instructions)
state {
basicBlocks = newBlock : state.basicBlocks,
curBlockName = blockName,
instructions = []}

-- === translation itself ===

toLLVMEntryFun' :: TopLamExpr -> TranslateM VoidS ()
toLLVMEntryFun' (TopLamExpr (Abs Empty body)) = do
trExpr body
return ()

trExpr :: Expr i -> TranslateM i L.Operand
trExpr = \case
Block resultTy block -> trBlock block
PrimOp resultTy op -> do
resultTy' <- trType resultTy
op' <- forM op trAtom
trPrimOp resultTy' op'

trType :: Type i -> TranslateM i L.Type
trType = \case
BaseType b -> return $ L.BaseType b
ProdType [] -> return L.VoidType
t -> error $ "not implemented: " ++ pprintStr t

trAtom :: Atom i -> TranslateM i L.Operand
trAtom = \case
Var v _ -> do
val <- lookupEnv v
return val
Lit v -> return $ L.Operand (L.Lit v) (L.BaseType (litType v))

trBlock :: Block i -> TranslateM i L.Operand
trBlock (Abs decls result) = case decls of
Empty -> trExpr result
Nest (Let b expr) rest -> do
val <- trExpr expr
extendEnv b val $ trBlock $ Abs rest result

trPrimOp :: L.Type -> PrimOp L.Operand -> TranslateM i L.Operand
trPrimOp resultTy op = case op of
BinOp b x y -> case b of
FAdd -> emitInstr resultTy $ L.FAdd x y
MiscOp op' -> case op' of
DebugPrintInt x -> undefined
4 changes: 2 additions & 2 deletions src/lib/TopLevel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import PPrint
import Simplify
import LLVMFFI
import ToLLVM
import Types.LLVM
import qualified Types.LLVM as L
import Types.Complicated
import Types.Primitives
import Types.Source hiding (CTopDecl)
Expand Down Expand Up @@ -104,7 +104,7 @@ execUDecl decl = do
CTopLet Nothing expr <- checkPass TypePass $ inferTopUDecl renamed
simpFun <- simplifyTopFun (exprAsNullaryFun expr)
logPass SimpPass simpFun
let tempFunName = "main" -- TODO: need to get a name
let tempFunName = L.Name "main" -- TODO: need to get a name
llvmContext <- TopperM $ asks topperLLVMContext
llvmFun <- toLLVMEntryFun tempFunName simpFun
logPass LLVMPass llvmFun
Expand Down
101 changes: 101 additions & 0 deletions src/lib/Types/LLVM.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
-- Copyright 2025 Google LLC
--
-- Use of this source code is governed by a BSD-style
-- license that can be found in the LICENSE file or at
-- https://developers.google.com/open-source/licenses/bsd

{-# LANGUAGE DuplicateRecordFields #-}

module Types.LLVM where

import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BS

import qualified Types.Primitives as P
import PPrint
import Util (bs2str)

-- this string doesn't include the `@` or `%` prefixes
newtype Name = Name { val :: ByteString }
type Binder = (Name, Type)

data Module = Module { functions :: [Function] }

data Function = Function
{ name :: Name
, params :: [Binder]
, body :: [BasicBlock] }

data BasicBlock = BasicBlock
{ name :: Name
, instructions :: [Decl]}

type Decl = (Maybe Name, Type, Instruction)
data Instruction =
FAdd Operand Operand
| Return Operand

data Operand = Operand { val :: UntypedOperand, ty :: Type }
data UntypedOperand =
LocalOcc Name
| Lit P.LitVal

data Type =
BaseType P.BaseType
| VoidType

-- === LLVM printing ===

-- This is load-bearing! We have to generate correct LLVM textual representation.

instance Pretty Function where
prLines (Function name [] body) = do
emitLine $ "define i32" <+> prTopName name <> "() {"
forM_ body \block -> do
emitLine ""
prLines block
emitLine "}"

prTopName :: Name -> BS.Builder
prTopName name = "@" <> BS.byteString name.val

prLocalName :: Name -> BS.Builder
prLocalName name = "%" <> BS.byteString name.val

prDecl :: Decl -> BS.Builder
prDecl (Just v, resultTy, instr) = prLocalName v <> " = " <> prInstr resultTy instr

prInstr :: Type -> Instruction -> BS.Builder
prInstr resultTy = \case
FAdd x y -> "fadd " <> pr resultTy <+> pr x.val <> ", " <> pr y.val

instance Pretty BasicBlock where
prLines (BasicBlock name decls) = do
emitLine $ pr name <> ":"
indent do
forM_ decls \decl -> emitLine $ prDecl decl

instance Pretty Name where
pr name = BS.byteString name.val

instance Pretty UntypedOperand where
pr = \case
LocalOcc v -> prLocalName v
Lit v -> pr v

instance Pretty Type where
pr = \case
BaseType (P.Scalar b) -> case b of
P.Float32Type -> "f32"
VoidType -> "void"


-- instance LLVMSer Operand where
-- lpr x = cat [lpr (getType x), ", ", printOperandWithoutType x]

-- instance Pretty Type where
-- pr = undefined


10 changes: 9 additions & 1 deletion src/lib/Types/Simple.hs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,15 @@ instance GenericE Type where
DepPairTy p -> Case4 $ p
TabPi t -> Case5 $ t

instance Pretty (Type n)
instance Pretty (Type n) where
pr = \case
BaseType b -> pr b
ProdType _ -> undefined
SumType _ -> undefined
RefType _ -> undefined
DepPairTy _ -> undefined
TabPi _ -> undefined

instance SinkableE Type
instance HoistableE Type
instance RenameE Type
Expand Down
6 changes: 3 additions & 3 deletions src/lib/Types/Source.hs
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ data PrintBackend =

data OutFormat = Printed (Maybe PrintBackend) | RenderHtml deriving (Show, Eq, Generic)

data PassName = Parse | RenamePass | TypePass | SimpPass | ImpPass | JitPass | LLVMPass
data PassName = Parse | RenamePass | TypePass | SimpPass | ImpPass | LLVMPass
| LLVMOpt | AsmPass | JAXPass | JAXSimpPass | LLVMEval | LowerOptPass | LowerPass
| ResultPass | JaxprAndHLO | EarlyOptPass | OptPass | VectPass | OccAnalysisPass
| InlinePass
Expand All @@ -597,13 +597,13 @@ data PassName = Parse | RenamePass | TypePass | SimpPass | ImpPass | JitPass | L
instance Show PassName where
show p = case p of
Parse -> "parse" ; RenamePass -> "rename"; TypePass -> "typed"
SimpPass -> "simp" ; ImpPass -> "imp" ; JitPass -> "llvm"
SimpPass -> "simp" ; ImpPass -> "imp"
LLVMOpt -> "llvmopt" ; AsmPass -> "asm"
JAXPass -> "jax" ; JAXSimpPass -> "jsimp"; ResultPass -> "result"
LLVMEval -> "llvmeval" ; JaxprAndHLO -> "jaxprhlo";
LowerOptPass -> "lower-opt"; LowerPass -> "lower"
EarlyOptPass -> "early-opt"; OptPass -> "opt"; OccAnalysisPass -> "occ-analysis"
VectPass -> "vect"; InlinePass -> "inline"
VectPass -> "vect"; InlinePass -> "inline"; LLVMPass -> "llvm"

data EnvQuery =
DumpSubst
Expand Down
Loading

0 comments on commit 7bf7ad9

Please sign in to comment.