|
| 1 | +{-# LANGUAGE TemplateHaskell #-} |
| 2 | + |
| 3 | +module Main where |
| 4 | + |
| 5 | +import Data.Functor ((<$>)) |
| 6 | +import Data.Maybe (fromJust) |
| 7 | +import System.FilePath (replaceExtension) |
| 8 | +import Control.Lens hiding (op) |
| 9 | +import Control.Applicative ((<*>)) |
| 10 | +import Control.Monad.State.Lazy (runStateT, StateT) |
| 11 | +import Control.Monad.Except (runExceptT, ExceptT) |
| 12 | +import System.Environment (getArgs) |
| 13 | +import LLVM.General.PrettyPrint (showPretty) |
| 14 | +import LLVM.General.Analysis (verify) |
| 15 | +import LLVM.General.PassManager (withPassManager, defaultCuratedPassSetSpec, optLevel, runPassManager) |
| 16 | +import LLVM.General.Target (withDefaultTargetMachine) |
| 17 | +import LLVM.General.Context (withContext) |
| 18 | +import LLVM.General.Module (withModuleFromLLVMAssembly, moduleAST, File(File)) |
| 19 | +import LLVM.General.AST.Instruction (Named(..), Instruction(..)) |
| 20 | +import LLVM.General.AST.Attribute (ParameterAttribute) |
| 21 | +import LLVM.General.AST.AddrSpace (AddrSpace(..)) |
| 22 | +import qualified Data.Map as M |
| 23 | +import qualified LLVM.General.AST as AST |
| 24 | +import qualified LLVM.General.Module as M |
| 25 | +import qualified LLVM.General.AST.Global as G |
| 26 | +import qualified LLVM.General.AST.Constant as C |
| 27 | +import qualified LLVM.General.AST.Type as T |
| 28 | +import qualified LLVM.General.AST.CallingConvention as CallingConvention |
| 29 | + |
| 30 | +data SourceLoc = SourceLoc Int Int FilePath deriving (Eq, Ord) |
| 31 | + |
| 32 | +type NumberedMetadata = M.Map AST.MetadataNodeID [Maybe AST.Operand] |
| 33 | +data ComputationState = ComputationState |
| 34 | + { _globalCounters :: M.Map SourceLoc AST.Operand |
| 35 | + , _introducedGlobals :: [AST.Definition] |
| 36 | + , _fresh :: Int |
| 37 | + , _numberedMetadata :: NumberedMetadata |
| 38 | + } |
| 39 | + |
| 40 | +makeLenses ''ComputationState |
| 41 | + |
| 42 | +state :: NumberedMetadata -> ComputationState |
| 43 | +state = ComputationState M.empty [] 0 |
| 44 | + |
| 45 | +type BlockMonad a = StateT ComputationState Identity a |
| 46 | + |
| 47 | +main :: IO () |
| 48 | +main = do |
| 49 | + cacheSource : target : _ <- getArgs |
| 50 | + cacheDefs <- AST.moduleDefinitions <$> readAssembly cacheSource |
| 51 | + parsed <- readAssembly target |
| 52 | + putStrLn $ showPretty cacheDefs |
| 53 | + let (inj, st) = runBlockMonad (state md) . mapM inject $ AST.moduleDefinitions parsed |
| 54 | + newDefs = _introducedGlobals st ++ cacheDefs ++ (injectPrinting (_globalCounters st) <$> inj) |
| 55 | + altered = parsed { AST.moduleDefinitions = newDefs } |
| 56 | + md = M.fromList [ (i, ops) | AST.MetadataNodeDefinition i ops <- AST.moduleDefinitions parsed ] |
| 57 | + asGeneralModule altered (\m -> do |
| 58 | + verifyResult <- runExceptT $ verify m |
| 59 | + case verifyResult of |
| 60 | + Left mess -> putStrLn $ "Verify error: " ++ mess |
| 61 | + Right _ -> do |
| 62 | + putStrLn "result: " |
| 63 | + withPassManager (defaultCuratedPassSetSpec {optLevel = Just 3}) $ \pm -> |
| 64 | + runPassManager pm m |
| 65 | + writeObjectFile (replaceExtension target "o") m |
| 66 | + printModule m |
| 67 | + ) |
| 68 | + |
| 69 | +injectPrinting :: M.Map SourceLoc AST.Operand -> AST.Definition -> AST.Definition |
| 70 | +injectPrinting locs = inner |
| 71 | + where |
| 72 | + inner (AST.GlobalDefinition f@G.Function{G.basicBlocks = bs, G.name = AST.Name "main"}) = AST.GlobalDefinition $ f {G.basicBlocks = map attachPrinting bs} |
| 73 | + inner d = d |
| 74 | + attachPrinting (AST.BasicBlock n i r@(Do AST.Ret{})) = AST.BasicBlock n (i ++ printIs) r |
| 75 | + attachPrinting b = b |
| 76 | + printIs = printI <$> M.toList locs |
| 77 | + printI (SourceLoc l c _, op) = Do $ Call False CallingConvention.C [] func [(cInt l, []), (cInt c, []), (op, [])] [] [] |
| 78 | + cInt = AST.ConstantOperand . C.Int 64 . toInteger |
| 79 | + func = Right . AST.ConstantOperand $ C.GlobalReference t (AST.Name "__printSimCacheData") |
| 80 | + t = T.FunctionType T.VoidType [T.i64, T.i64, T.PointerType counterType (AddrSpace 0)] False |
| 81 | + |
| 82 | +inject :: AST.Definition -> BlockMonad AST.Definition |
| 83 | +inject (AST.GlobalDefinition f@G.Function{G.basicBlocks = blocks}) = do |
| 84 | + newBlocks <- mapM iBlock blocks |
| 85 | + return . AST.GlobalDefinition $ f {G.basicBlocks = newBlocks} |
| 86 | + where |
| 87 | + iBlock (AST.BasicBlock n is t) = AST.BasicBlock n <$> recurse is <*> return t |
| 88 | + recurse [] = return [] |
| 89 | + recurse (i:is) = case unName i of |
| 90 | + Load{address = ptrOp, metadata = md} -> access ptrOp md $ (i:) <$> recurse is |
| 91 | + Store{address = ptrOp, metadata = md} -> access ptrOp md $ (i:) <$> recurse is |
| 92 | + _ -> (i:) <$> recurse is |
| 93 | + unName (Do i) = i |
| 94 | + unName (_ := i) = i |
| 95 | + |
| 96 | +inject d = return d |
| 97 | + |
| 98 | +access :: AST.Operand -> [(String, AST.MetadataNode)] -> BlockMonad [Named Instruction] -> BlockMonad [Named Instruction] |
| 99 | +access op md cont = do |
| 100 | + counterOp <- getLoc md >>= getCounter |
| 101 | + (ptrOp, castInstr) <- cast |
| 102 | + callInstr <- callCache [(ptrOp, []), (counterOp, [])] |
| 103 | + (castInstr :) . (callInstr :) <$> cont |
| 104 | + where |
| 105 | + cast = do |
| 106 | + nInt <- fresh <<+= 1 |
| 107 | + let name = AST.Name $ "bitcastthing" ++ show nInt |
| 108 | + return (AST.LocalReference t name, name := BitCast op t []) |
| 109 | + t = T.PointerType T.i8 (AddrSpace 0) |
| 110 | + |
| 111 | +getCounter :: SourceLoc -> BlockMonad AST.Operand |
| 112 | +getCounter l = use (globalCounters . at l) >>= \mo -> case mo of |
| 113 | + Just o -> return o |
| 114 | + Nothing -> do |
| 115 | + nInt <- fresh <<+= 1 |
| 116 | + let g = G.globalVariableDefaults {G.name = name, G.type' = counterType, G.initializer = Just $ C.Struct Nothing False [C.Int 64 0, C.Int 64 0]} |
| 117 | + name = AST.Name $ "globForLoc" ++ show nInt |
| 118 | + op = AST.ConstantOperand $ C.GlobalReference (T.PointerType counterType (AddrSpace 0)) name |
| 119 | + introducedGlobals %= (AST.GlobalDefinition g:) |
| 120 | + globalCounters . at l ?= op |
| 121 | + return op |
| 122 | + |
| 123 | +callCache :: [(AST.Operand, [ParameterAttribute])] -> BlockMonad (Named Instruction) |
| 124 | +callCache params = return . Do $ Call False CallingConvention.C [] func params [] [] |
| 125 | + where |
| 126 | + func = Right . AST.ConstantOperand $ C.GlobalReference t (AST.Name "__memory_blub") |
| 127 | + t = T.FunctionType T.VoidType [T.PointerType T.i8 (AddrSpace 0), T.PointerType counterType (AddrSpace 0)] False |
| 128 | + |
| 129 | +counterType :: T.Type |
| 130 | +counterType = T.StructureType False [T.i64, T.i64] |
| 131 | + |
| 132 | +getLoc :: [(String, AST.MetadataNode)] -> BlockMonad SourceLoc |
| 133 | +getLoc md = case lookup "dbg" md of |
| 134 | + Nothing -> error $ "Couldn't find dbg in " ++ show md |
| 135 | + Just (AST.MetadataNode l) -> inner l |
| 136 | + Just (AST.MetadataNodeReference i) -> fromJust <$> use (numberedMetadata . at i) >>= inner |
| 137 | + where |
| 138 | + inner :: [Maybe AST.Operand] -> BlockMonad SourceLoc |
| 139 | + inner (l : c : Just scope : _) = SourceLoc (getVal l) (getVal c) <$> case scope of |
| 140 | + AST.MetadataNodeOperand (AST.MetadataNodeReference r) -> getStr . head <$> |
| 141 | + (readRef r >>= readRef . getRef . (!! 1)) |
| 142 | + getVal (Just (AST.ConstantOperand (C.Int _ v))) = fromInteger v |
| 143 | + getStr (Just (AST.MetadataStringOperand s)) = s |
| 144 | + getRef (Just (AST.MetadataNodeOperand (AST.MetadataNodeReference r))) = r |
| 145 | + readRef r = fromJust <$> use (numberedMetadata . at r) |
| 146 | + |
| 147 | +runBlockMonad :: ComputationState -> BlockMonad a -> (a, ComputationState) |
| 148 | +runBlockMonad st m = runIdentity $ runStateT m st |
| 149 | + |
| 150 | +readAssembly :: FilePath -> IO AST.Module |
| 151 | +readAssembly path = withContext $ \c -> |
| 152 | + failIO $ withModuleFromLLVMAssembly c (File path) moduleAST |
| 153 | + |
| 154 | +failIO :: Show err => ExceptT err IO a -> IO a |
| 155 | +failIO e = runExceptT e >>= \r -> case r of |
| 156 | + Left err -> fail $ show err |
| 157 | + Right a -> return a |
| 158 | + |
| 159 | +boolean :: a -> a -> Bool -> a |
| 160 | +boolean a _ True = a |
| 161 | +boolean _ a False = a |
| 162 | + |
| 163 | +writeObjectFile :: FilePath -> M.Module -> IO () |
| 164 | +writeObjectFile path m = failIO . withDefaultTargetMachine $ \mac -> failIO $ M.writeObjectToFile mac (M.File path) m |
| 165 | + |
| 166 | +asGeneralModule :: AST.Module -> (M.Module -> IO a) -> IO a |
| 167 | +asGeneralModule m monad = do |
| 168 | + eRes <- withContext $ \context -> |
| 169 | + runExceptT . M.withModuleFromAST context m $ monad |
| 170 | + case eRes of |
| 171 | + Left mess -> fail mess |
| 172 | + Right res -> return res |
| 173 | + |
| 174 | +printModule :: M.Module -> IO () |
| 175 | +printModule m = M.moduleLLVMAssembly m >>= putStrLn |
0 commit comments