From 499e4e42d462f436a5267ddf0c2f73d5741a8248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karl-Filip=20Fax=C3=A9n?= Date: Tue, 16 Mar 2021 17:12:25 +0100 Subject: [PATCH] Make resnet18 work again. --- src/Feldspar/Onnx/Operators.hs | 36 ++++++++++++---------------------- src/Onnx/OnnxToFeld.hs | 29 +++++++++++++++------------ 2 files changed, 29 insertions(+), 36 deletions(-) diff --git a/src/Feldspar/Onnx/Operators.hs b/src/Feldspar/Onnx/Operators.hs index a1557e0..8615266 100644 --- a/src/Feldspar/Onnx/Operators.hs +++ b/src/Feldspar/Onnx/Operators.hs @@ -263,36 +263,19 @@ onnxBatchNormalization attrs xs gamma beta mean var = ys epsilon = value $ P.realToFrac $ getAttr attrs aaFloat 1e-5 "epsilon" -- | Flatten a tensor to a matrix --- We use Push vectors for flattening even though fusion is lost since flattening based on +-- We use Manifest vectors for flattening even though fusion is lost since flattening based on -- Pull vectors leads to index expressions containing division and modulus that have a greater -- performance impact than the loss of fusion. -onnxFlatten :: Syntax a => Attrs -> Pull sh a -> Pull DIM2 a -onnxFlatten attrs vec = toPull $ store $ flatPush (P.fromIntegral $ getAttr attrs aaInt 1 "axis") $ toPush vec - -onnxFlatten' :: (Storable vec, Syntax a) => Attrs -> vec a -> Pull DIM2 a -onnxFlatten' attrs xs = toPull $ flatMan d $ store xs +onnxFlatten :: (Storable vec, Syntax a) => Attrs -> vec a -> Manifest DIM2 a +onnxFlatten attrs xs = flatMan d $ store xs where d = P.fromIntegral $ getAttr attrs aaInt 1 "axis" +-- | Flatten a Manifest vector to two dimensions by changing its extent flatMan :: Int -> Manifest sh a -> Manifest DIM2 a flatMan d (Manifest arr sh) = Manifest arr sh' where sh' = Z :. P.product ls :. P.product rs (ls,rs) = takeDropShape d sh --- | Flattening a Push vector to two dimensions -flatPush :: forall sh a . Int -> Push sh a -> Push DIM2 a -flatPush i (Push ixf ext) = Push ixf' $ Z :. P.product ls :. P.product rs - where ixf' :: PushK DIM2 a - ixf' wf = ixf (\ sh d -> let (ils,irs) = takeDropShape i sh - in wf (Z :. idxExp ls ils :. idxExp rs irs) d) - (ls,rs) = takeDropShape i ext - --- | Linearizing an index where both intex and extent are represented as lists -idxExp :: [Data Length] -> [Data Length] -> Data Length -idxExp ext idx = f (P.reverse ext) (P.reverse idx) - where f (n:ns) (i:is) = f ns is * n + i - f [] [] = 0 - f _ _ = P.error "Operators.idxExp: extent and index differ in length" - -- | Split a shape takeDropShape :: Int -> Shape sh -> ([Data Length], [Data Length]) takeDropShape i sh = P.splitAt j es @@ -301,8 +284,8 @@ takeDropShape i sh = P.splitAt j es -- | Matrix multiplication of two dimensional temsors onnxGemm :: (RealFloat a, Numeric a, Pully vec, VecShape vec ~ DIM2, Pully vec2, - UnionShape DIM2 (VecShape vec2) ~ DIM2, Storable vec) - => Attrs -> vec (Data a) -> vec (Data a) -> vec2 (Data a) -> DPull DIM2 a + UnionShape DIM2 (VecShape vec2) ~ DIM2, Storable vec, Pully vec', Storable vec', VecShape vec' ~ DIM2) + => Attrs -> vec (Data a) -> vec' (Data a) -> vec2 (Data a) -> DPull DIM2 a onnxGemm attrs vA vB vC = bcZipWith (+) (mmT vAT vBnT) $ toPull vC where vA' = if alpha P.== 1.0 then toPull $ store vA else map (* value alpha) $ toPull $ store vA vAT = if transA P.== 1 then transpose vA' else vA' @@ -590,3 +573,10 @@ setSizeManifest3 :: Length -> Length -> Length -> Manifest DIM3 a -> Manifest DI setSizeManifest3 s1 s2 s3 (Manifest arr (Z :. e1 :. e2 :. e3)) = Manifest (cap (singletonRange (s1*s2*s3) :> universal) arr) (Z :. cap (singletonRange s1) e1 :. cap (singletonRange s2) e2 :. cap (singletonRange s3) e3) + +-- | Add range info to four dimensinal Manifest vector +setSizeManifest4 :: Length -> Length -> Length -> Length -> Manifest DIM4 a -> Manifest DIM4 a +setSizeManifest4 s1 s2 s3 s4 (Manifest arr (Z :. e1 :. e2 :. e3 :. e4)) + = Manifest (cap (singletonRange (s1*s2*s3*s4) :> universal) arr) + (Z :. cap (singletonRange s1) e1 :. cap (singletonRange s2) e2 :. cap (singletonRange s3) e3 + :. cap (singletonRange s4) e4) diff --git a/src/Onnx/OnnxToFeld.hs b/src/Onnx/OnnxToFeld.hs index 66c1740..3ea3ac6 100644 --- a/src/Onnx/OnnxToFeld.hs +++ b/src/Onnx/OnnxToFeld.hs @@ -203,7 +203,7 @@ mkProgramFile gr initGroups sInits inputs multiUses shapes useNative where name = mangle $ fromJust $ G.name gr params = if null initGroups then inputPs else "(weights :: WeightRec)" : inputPs inputPs = map mkParam inputs - accesses = map mkAccess $ concat initGroups + accesses = map (mkAccess useNative) $ concat initGroups tEnv = M.fromList [(fromJust $ TP.name t, t) | t <- D.toList $ G.initializer gr] mkParam ti = "(" <> mangle (vipName ti) <> "' :: " <> shTy (vipType ti) <> ")" shTy (d,t) = "Manifest DIM" <> show d <> " (Data " <> showElemT t <> ")" @@ -214,7 +214,7 @@ mkProgramFile gr initGroups sInits inputs multiUses shapes useNative -- | Make a possibly shape contraining input binding mkInputCap :: V.ValueInfoProto -> Maybe String -> L.ByteString -mkInputCap vi msh = mangle (vipName vi) <> " = toPull $ " <> go msh <> mangle (vipName vi) <> "'" +mkInputCap vi msh = mangle (vipName vi) <> " = " <> go msh <> mangle (vipName vi) <> "'" where go Nothing = "" go (Just s) = "setSizeManifest" <> show (length sh) <> " " <> L.unwords (map show sh) <> " " where sh = read $ "[" <> s <> "]" :: [Int] @@ -251,10 +251,11 @@ shapeToDim :: TP.TensorProto -> Int64 shapeToDim p = TP.dims p `D.index` 0 -- | Read from the weight record -mkAccess :: TensorInfo -> L.ByteString -mkAccess ti = vname <> " = " <> setS <> " $ sel (Proxy @" <> show (tiField ti) <> ") weights ! (Z :. " <> show (tiIdx ti) <> ")" +mkAccess :: Bool -> TensorInfo -> L.ByteString +mkAccess useSize ti = vname <> " = " <> setS <> "sel (Proxy @" <> show (tiField ti) <> ") weights ! (Z :. " <> show (tiIdx ti) <> ")" where vname = mangle $ tiName ti - setS = "setSizePull" <> show (length dims) <> L.concat (map (\d -> " " <> show d) dims) + setS | useSize = "setSizePull" <> show (length dims) <> L.concat (map (\d -> " " <> show d) dims) <> " $ " + | otherwise = "" dims = D.toList $ tiDims ti -- | Initialize a small tensor @@ -334,7 +335,7 @@ mkMainFile bname hf weightFile weightRecTC outputs inputs noWeightRec useNative , " }" , "" ] - ++ wread ++ concat (zipWith (mkArgRead bname useNative) inputs [1..]) ++ + ++ wread ++ concat (zipWith (mkArgRead ebname noWeightRec useNative) inputs [1..]) ++ [ " " <> ot <> " " <> ov <> " = {0};" , "" , " " <> functionName <> "(" <> warg <> inArgs <> ", &" <> ov <> ");" @@ -346,8 +347,9 @@ mkMainFile bname hf weightFile weightRecTC outputs inputs noWeightRec useNative ] (ov, _, oCode) = mkOutput useNative outputs inArgs = L.intercalate ", " ["&" <> mangle (vipName v) | v <- inputs] - functionName = fromString $ encodeFunctionName bname - ot = argumentType bname $ length inputs + 1 + functionName = fromString ebname + ot = argumentType ebname noWeightRec $ length inputs + 1 + ebname = encodeFunctionName bname wread | noWeightRec = [] | otherwise = [" weight_rec_t * w = read_constants(\"" <> fromString weightFile <> "\");" , "" @@ -365,15 +367,15 @@ mkMainFile bname hf weightFile weightRecTC outputs inputs noWeightRec useNative ] -- | Generate code to read the argument tensors from file -mkArgRead :: FilePath -> Bool -> V.ValueInfoProto -> Int -> [L.ByteString] -mkArgRead bname useNative vip i +mkArgRead :: FilePath -> Bool -> Bool -> V.ValueInfoProto -> Int -> [L.ByteString] +mkArgRead bname noWeightRec useNative vip i = [ " FILE* " <> fname <> " = fopen(argv[" <> show i <> "], \"r\");" , " if (" <> fname <> " == NULL) {" , " fprintf(stderr, \"Could not open %s for reading.\\n\", argv[" <> show i <> "]);" , " exit(1);" , " }" , "" - , " " <> argumentType bname i <> " " <> vname <> ";" + , " " <> argumentType bname noWeightRec i <> " " <> vname <> ";" , "" , allocReadTensor useNative fname vname n elemT , " fclose(" <> fname <> ");" @@ -384,8 +386,9 @@ mkArgRead bname useNative vip i (n, elemT) = vipType vip -- | Argument type name -argumentType :: FilePath -> Int -> L.ByteString -argumentType bname i = fromString $ "arg_" <> show i <> "_" <> bname <> "_t" +argumentType :: FilePath -> Bool -> Int -> L.ByteString +argumentType bname noWeightRec i = fromString $ "arg_" <> show j <> "_" <> bname <> "_t" + where j = if noWeightRec then i else i+1 -- | Compute the (Program) Type that corresponds to a tensor tiToType :: TensorInfo -> Type