Skip to content

Commit

Permalink
Make resnet18 work again.
Browse files Browse the repository at this point in the history
  • Loading branch information
kffaxen authored and pjonsson committed Mar 17, 2021
1 parent a55d3c0 commit 499e4e4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 36 deletions.
36 changes: 13 additions & 23 deletions src/Feldspar/Onnx/Operators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -263,36 +263,19 @@ onnxBatchNormalization attrs xs gamma beta mean var = ys
epsilon = value $ P.realToFrac $ getAttr attrs aaFloat 1e-5 "epsilon"

-- | Flatten a tensor to a matrix
-- We use Push vectors for flattening even though fusion is lost since flattening based on
-- We use Manifest vectors for flattening even though fusion is lost since flattening based on
-- Pull vectors leads to index expressions containing division and modulus that have a greater
-- performance impact than the loss of fusion.
onnxFlatten :: Syntax a => Attrs -> Pull sh a -> Pull DIM2 a
onnxFlatten attrs vec = toPull $ store $ flatPush (P.fromIntegral $ getAttr attrs aaInt 1 "axis") $ toPush vec

onnxFlatten' :: (Storable vec, Syntax a) => Attrs -> vec a -> Pull DIM2 a
onnxFlatten' attrs xs = toPull $ flatMan d $ store xs
onnxFlatten :: (Storable vec, Syntax a) => Attrs -> vec a -> Manifest DIM2 a
onnxFlatten attrs xs = flatMan d $ store xs
where d = P.fromIntegral $ getAttr attrs aaInt 1 "axis"

-- | Flatten a Manifest vector to two dimensions by changing its extent
flatMan :: Int -> Manifest sh a -> Manifest DIM2 a
flatMan d (Manifest arr sh) = Manifest arr sh'
where sh' = Z :. P.product ls :. P.product rs
(ls,rs) = takeDropShape d sh

-- | Flattening a Push vector to two dimensions
flatPush :: forall sh a . Int -> Push sh a -> Push DIM2 a
flatPush i (Push ixf ext) = Push ixf' $ Z :. P.product ls :. P.product rs
where ixf' :: PushK DIM2 a
ixf' wf = ixf (\ sh d -> let (ils,irs) = takeDropShape i sh
in wf (Z :. idxExp ls ils :. idxExp rs irs) d)
(ls,rs) = takeDropShape i ext

-- | Linearizing an index where both intex and extent are represented as lists
idxExp :: [Data Length] -> [Data Length] -> Data Length
idxExp ext idx = f (P.reverse ext) (P.reverse idx)
where f (n:ns) (i:is) = f ns is * n + i
f [] [] = 0
f _ _ = P.error "Operators.idxExp: extent and index differ in length"

-- | Split a shape
takeDropShape :: Int -> Shape sh -> ([Data Length], [Data Length])
takeDropShape i sh = P.splitAt j es
Expand All @@ -301,8 +284,8 @@ takeDropShape i sh = P.splitAt j es

-- | Matrix multiplication of two dimensional temsors
onnxGemm :: (RealFloat a, Numeric a, Pully vec, VecShape vec ~ DIM2, Pully vec2,
UnionShape DIM2 (VecShape vec2) ~ DIM2, Storable vec)
=> Attrs -> vec (Data a) -> vec (Data a) -> vec2 (Data a) -> DPull DIM2 a
UnionShape DIM2 (VecShape vec2) ~ DIM2, Storable vec, Pully vec', Storable vec', VecShape vec' ~ DIM2)
=> Attrs -> vec (Data a) -> vec' (Data a) -> vec2 (Data a) -> DPull DIM2 a
onnxGemm attrs vA vB vC = bcZipWith (+) (mmT vAT vBnT) $ toPull vC
where vA' = if alpha P.== 1.0 then toPull $ store vA else map (* value alpha) $ toPull $ store vA
vAT = if transA P.== 1 then transpose vA' else vA'
Expand Down Expand Up @@ -590,3 +573,10 @@ setSizeManifest3 :: Length -> Length -> Length -> Manifest DIM3 a -> Manifest DI
setSizeManifest3 s1 s2 s3 (Manifest arr (Z :. e1 :. e2 :. e3))
= Manifest (cap (singletonRange (s1*s2*s3) :> universal) arr)
(Z :. cap (singletonRange s1) e1 :. cap (singletonRange s2) e2 :. cap (singletonRange s3) e3)

-- | Add range info to four dimensinal Manifest vector
setSizeManifest4 :: Length -> Length -> Length -> Length -> Manifest DIM4 a -> Manifest DIM4 a
setSizeManifest4 s1 s2 s3 s4 (Manifest arr (Z :. e1 :. e2 :. e3 :. e4))
= Manifest (cap (singletonRange (s1*s2*s3*s4) :> universal) arr)
(Z :. cap (singletonRange s1) e1 :. cap (singletonRange s2) e2 :. cap (singletonRange s3) e3
:. cap (singletonRange s4) e4)
29 changes: 16 additions & 13 deletions src/Onnx/OnnxToFeld.hs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ mkProgramFile gr initGroups sInits inputs multiUses shapes useNative
where name = mangle $ fromJust $ G.name gr
params = if null initGroups then inputPs else "(weights :: WeightRec)" : inputPs
inputPs = map mkParam inputs
accesses = map mkAccess $ concat initGroups
accesses = map (mkAccess useNative) $ concat initGroups
tEnv = M.fromList [(fromJust $ TP.name t, t) | t <- D.toList $ G.initializer gr]
mkParam ti = "(" <> mangle (vipName ti) <> "' :: " <> shTy (vipType ti) <> ")"
shTy (d,t) = "Manifest DIM" <> show d <> " (Data " <> showElemT t <> ")"
Expand All @@ -214,7 +214,7 @@ mkProgramFile gr initGroups sInits inputs multiUses shapes useNative

-- | Make a possibly shape contraining input binding
mkInputCap :: V.ValueInfoProto -> Maybe String -> L.ByteString
mkInputCap vi msh = mangle (vipName vi) <> " = toPull $ " <> go msh <> mangle (vipName vi) <> "'"
mkInputCap vi msh = mangle (vipName vi) <> " = " <> go msh <> mangle (vipName vi) <> "'"
where go Nothing = ""
go (Just s) = "setSizeManifest" <> show (length sh) <> " " <> L.unwords (map show sh) <> " "
where sh = read $ "[" <> s <> "]" :: [Int]
Expand Down Expand Up @@ -251,10 +251,11 @@ shapeToDim :: TP.TensorProto -> Int64
shapeToDim p = TP.dims p `D.index` 0

-- | Read from the weight record
mkAccess :: TensorInfo -> L.ByteString
mkAccess ti = vname <> " = " <> setS <> " $ sel (Proxy @" <> show (tiField ti) <> ") weights ! (Z :. " <> show (tiIdx ti) <> ")"
mkAccess :: Bool -> TensorInfo -> L.ByteString
mkAccess useSize ti = vname <> " = " <> setS <> "sel (Proxy @" <> show (tiField ti) <> ") weights ! (Z :. " <> show (tiIdx ti) <> ")"
where vname = mangle $ tiName ti
setS = "setSizePull" <> show (length dims) <> L.concat (map (\d -> " " <> show d) dims)
setS | useSize = "setSizePull" <> show (length dims) <> L.concat (map (\d -> " " <> show d) dims) <> " $ "
| otherwise = ""
dims = D.toList $ tiDims ti

-- | Initialize a small tensor
Expand Down Expand Up @@ -334,7 +335,7 @@ mkMainFile bname hf weightFile weightRecTC outputs inputs noWeightRec useNative
, " }"
, ""
]
++ wread ++ concat (zipWith (mkArgRead bname useNative) inputs [1..]) ++
++ wread ++ concat (zipWith (mkArgRead ebname noWeightRec useNative) inputs [1..]) ++
[ " " <> ot <> " " <> ov <> " = {0};"
, ""
, " " <> functionName <> "(" <> warg <> inArgs <> ", &" <> ov <> ");"
Expand All @@ -346,8 +347,9 @@ mkMainFile bname hf weightFile weightRecTC outputs inputs noWeightRec useNative
]
(ov, _, oCode) = mkOutput useNative outputs
inArgs = L.intercalate ", " ["&" <> mangle (vipName v) | v <- inputs]
functionName = fromString $ encodeFunctionName bname
ot = argumentType bname $ length inputs + 1
functionName = fromString ebname
ot = argumentType ebname noWeightRec $ length inputs + 1
ebname = encodeFunctionName bname
wread | noWeightRec = []
| otherwise = [" weight_rec_t * w = read_constants(\"" <> fromString weightFile <> "\");"
, ""
Expand All @@ -365,15 +367,15 @@ mkMainFile bname hf weightFile weightRecTC outputs inputs noWeightRec useNative
]

-- | Generate code to read the argument tensors from file
mkArgRead :: FilePath -> Bool -> V.ValueInfoProto -> Int -> [L.ByteString]
mkArgRead bname useNative vip i
mkArgRead :: FilePath -> Bool -> Bool -> V.ValueInfoProto -> Int -> [L.ByteString]
mkArgRead bname noWeightRec useNative vip i
= [ " FILE* " <> fname <> " = fopen(argv[" <> show i <> "], \"r\");"
, " if (" <> fname <> " == NULL) {"
, " fprintf(stderr, \"Could not open %s for reading.\\n\", argv[" <> show i <> "]);"
, " exit(1);"
, " }"
, ""
, " " <> argumentType bname i <> " " <> vname <> ";"
, " " <> argumentType bname noWeightRec i <> " " <> vname <> ";"
, ""
, allocReadTensor useNative fname vname n elemT
, " fclose(" <> fname <> ");"
Expand All @@ -384,8 +386,9 @@ mkArgRead bname useNative vip i
(n, elemT) = vipType vip

-- | Argument type name
argumentType :: FilePath -> Int -> L.ByteString
argumentType bname i = fromString $ "arg_" <> show i <> "_" <> bname <> "_t"
argumentType :: FilePath -> Bool -> Int -> L.ByteString
argumentType bname noWeightRec i = fromString $ "arg_" <> show j <> "_" <> bname <> "_t"
where j = if noWeightRec then i else i+1

-- | Compute the (Program) Type that corresponds to a tensor
tiToType :: TensorInfo -> Type
Expand Down

0 comments on commit 499e4e4

Please sign in to comment.