Skip to content

Commit

Permalink
Fix problems with Expand.
Browse files Browse the repository at this point in the history
This patch fixes a bug in Expand where (array expanded) expressions
are erroneously floated past let bindings. It also rectifies a case
where an expression is not floated out of a loop because the loop is
mistaken for a cheap expression where floating should be disabled. It
also avoids unnecessary inlining of let bindings (which can lead to
introduction of redundant computations) when no expression is
floated out of the let body.
  • Loading branch information
kffaxen authored and pjonsson committed Mar 14, 2021
1 parent 12cf9e0 commit 6f03cdd
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 63 deletions.
40 changes: 24 additions & 16 deletions src/Feldspar/Core/Middleend/Expand.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@
-- FIXME: Fix the incomplete patterns.
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}

-- Handle situations where an expression in a loop nest depends on the innermost loop
-- (so ordinary loop invariant removal does not help) but is invariant with respect to
-- some outer loop. To this end we expand the value to an array with dimensions as the
-- inner loops the expression depends on. The (array) expression can now be floated
-- out of the inner loops (because of the array expansion) as well as the outer invariant
-- loop(s).
-- This transformation is important since the unrestricted inlining by the vector library
-- in combination with the embedding mechanism often creates this kind of loop nest.

module Feldspar.Core.Middleend.Expand (expand) where

import Feldspar.Core.UntypedRepresentation
Expand Down Expand Up @@ -84,7 +93,7 @@ expF ai vm e = mdo (s, (bs1,e1)) <- ea aiNew vm e
let arrVE = In r $ Variable arrV
b = BI {biAbsI = map snd profLoops, biBindIs = bs3, biBind = (arrV, flArr)}
refE = In (getAnnotation e) $ App GetIx (typeof e) [arrVE, idxE]
return (s, if null keepLoops || not (sharable e1) || simpleExp ai vm e1
return (s, if null keepLoops || simpleExp ai vm e1
then (bs2, e2) -- We should not float e
else ([b], refE)) -- We float an expanded e in b

Expand Down Expand Up @@ -147,17 +156,18 @@ eu :: [AbsInfo] -> VarMap -> RExp -> Rename (S.Set Var, ([BindInfo], RExp))
eu _ vm (Variable v) = let (s,e) = vm M.! v in return (s, ([],e))
eu _ _ (Literal l) = return (S.empty, ([], Literal l))
eu ai vm (App Let t [eRhs, In r (Lambda v eBody)])
= do (fvsR, bseR) <- expE ai vm eRhs
let (bsR, eR) = bseR -- Note [Lazy binding]
inline = not (sharable eR) || simpleArrRef eR
eR1 = if inline then dropAnnotation eR else Variable v
vmB = M.insert v (fvsR, eR1) vm
aiB = if inline then ai else AbsI (S.singleton v) : ai
(fvsB, bseB) <- expE aiB vmB eBody
let (bsB, eB) = bseB
eNew = App Let t [eR, In r $ Lambda v eB]
bsB1 = if inline then bsB else shiftBIs bsB
return (fvsB, (bsR ++ bsB1, if inline then dropAnnotation eB else eNew))
= mdo (fvsR, bseR) <- expE ai vm eRhs
let (bsR, eR) = bseR -- Note [Lazy binding]
tryInline = simpleExp ai vm eR
doInline = tryInline && not (null bsB) -- Inline only if we generate any bindings
eR1 = if doInline then dropAnnotation eR else Variable v
vmB = M.insert v (fvsR, eR1) vm
aiB = if tryInline then ai else AbsI (S.singleton v) : ai
(fvsB, bseB) <- expE aiB vmB eBody
let (bsB, eB) = bseB
eNew = App Let t [eR, In r $ Lambda v eB]
bsB1 = if tryInline then bsB else shiftBIs bsB
return (fvsB, (bsR ++ bsB1, if doInline then dropAnnotation eB else eNew))
eu ai vm (App op t [eLen, eInit, In r1 (Lambda vIx (In r2 (Lambda vSt eBody)))])
| op `elem` [ForLoop, Sequential]
= do (fvsL, bseL) <- expE ai vm eLen
Expand Down Expand Up @@ -203,13 +213,11 @@ eu ai vm (Lambda v e)
return (fvs S.\\ vs, (shiftBIs $ fst bse, Lambda v $ snd bse))

simpleExp :: [AbsInfo] -> VarMap -> UExp -> Bool
simpleExp ai vm e = simpleArrRef e || expCost ai vm e <= 2
simpleExp ai vm e = not (goodToShare e) || simpleArrRef e || expCost ai vm e <= 2

expCost :: [AbsInfo] -> VarMap -> UExp -> Int
expCost ai vm (In _ e) = go e
where go (Variable _) = 0
go (Literal _) = 0
go (App op _ es) = appCost ai vm op es
where go (App op _ es) = appCost ai vm op es
go _ = 5

appCost :: [AbsInfo] -> VarMap -> Op -> [UExp] -> Int
Expand Down
7 changes: 5 additions & 2 deletions tests/gold/concatV.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ void concatV(struct awl_awl_signedS32 * v1, struct awl_signedS32 * out)
uint32_t len0;
struct awl_signedS32 v6 = { 0 };
uint32_t v11;
struct awl_signedS32 v8 = { 0 };
uint32_t v9;
uint32_t len1;
struct awl_signedS32 e2 = { 0 };
Expand All @@ -18,7 +19,8 @@ void concatV(struct awl_awl_signedS32 * v1, struct awl_signedS32 * out)
for (uint32_t v5 = 0; v5 < len0; v5 += 1)
{
v11 = (v26).length;
v9 = ((*v1).buffer[v5]).length;
v8 = (*v1).buffer[v5];
v9 = (v8).length;
len1 = (v11 + v9);
(v6).buffer = initArray((v6).buffer, (v6).length, sizeof(int32_t), len1);
(v6).length = len1;
Expand All @@ -28,7 +30,7 @@ void concatV(struct awl_awl_signedS32 * v1, struct awl_signedS32 * out)
}
for (uint32_t v20 = 0; v20 < v9; v20 += 1)
{
(v6).buffer[(v20 + v11)] = ((*v1).buffer[v5]).buffer[v20];
(v6).buffer[(v20 + v11)] = (v8).buffer[v20];
}
e2 = v26;
v26 = v6;
Expand All @@ -43,4 +45,5 @@ void concatV(struct awl_awl_signedS32 * v1, struct awl_signedS32 * out)
}
freeArray((v26).buffer);
freeArray((v6).buffer);
freeArray((v8).buffer);
}
7 changes: 5 additions & 2 deletions tests/gold/concatVM.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ void concatVM(struct awl_awl_signedS32 * v1, struct awl_signedS32 * out)
struct awl_signedS32 e1 = { 0 };
struct awl_signedS32 v7 = { 0 };
uint32_t v12;
struct awl_signedS32 v9 = { 0 };
uint32_t v10;
uint32_t len2;
struct awl_signedS32 e3 = { 0 };
Expand All @@ -18,7 +19,8 @@ void concatVM(struct awl_awl_signedS32 * v1, struct awl_signedS32 * out)
for (uint32_t v6 = 0; v6 < len0; v6 += 1)
{
v12 = (e1).length;
v10 = ((*v1).buffer[v6]).length;
v9 = (*v1).buffer[v6];
v10 = (v9).length;
len2 = (v12 + v10);
(v7).buffer = initArray((v7).buffer, (v7).length, sizeof(int32_t), len2);
(v7).length = len2;
Expand All @@ -28,12 +30,13 @@ void concatVM(struct awl_awl_signedS32 * v1, struct awl_signedS32 * out)
}
for (uint32_t v21 = 0; v21 < v10; v21 += 1)
{
(v7).buffer[(v21 + v12)] = ((*v1).buffer[v6]).buffer[v21];
(v7).buffer[(v21 + v12)] = (v9).buffer[v21];
}
e3 = e1;
e1 = v7;
v7 = e3;
}
*out = e1;
freeArray((v7).buffer);
freeArray((v9).buffer);
}
14 changes: 10 additions & 4 deletions tests/gold/metrics.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ void metrics(struct awl_signedS32 * v1, struct awl_signedS32 * v2, struct awl_aw
uint32_t v10;
uint32_t v9;
struct awl_awl_signedS32 v33 = { 0 };
struct awl_s_2_2xunsignedS32 v16 = { 0 };
uint32_t v18;
struct awl_signedS32 st0 = { 0 };
struct awl_signedS32 * v14 = NULL;
struct awl_signedS32 v36 = { 0 };
uint32_t v37;

v10 = (*v3).length;
Expand All @@ -59,27 +61,31 @@ void metrics(struct awl_signedS32 * v1, struct awl_signedS32 * v2, struct awl_aw
(v33).length = v10;
for (uint32_t v13 = 0; v13 < v10; v13 += 1)
{
v18 = min(((*v3).buffer[v13]).length, v9);
v16 = (*v3).buffer[v13];
v18 = min((v16).length, v9);
((v33).buffer[v13]).buffer = initArray(((v33).buffer[v13]).buffer, ((v33).buffer[v13]).length, sizeof(int32_t), v18);
((v33).buffer[v13]).length = v18;
for (uint32_t v24 = 0; v24 < v18; v24 += 1)
{
((v33).buffer[v13]).buffer[v24] = (*v14).buffer[(((*v3).buffer[v13]).buffer[v24]).member1];
((v33).buffer[v13]).buffer[v24] = (*v14).buffer[((v16).buffer[v24]).member1];
}
v14 = &(v33).buffer[v13];
}
(*out).buffer = initArray_awl_signedS32((*out).buffer, (*out).length, v10);
(*out).length = v10;
for (uint32_t v34 = 0; v34 < v10; v34 += 1)
{
v37 = ((v33).buffer[v34]).length;
v36 = (v33).buffer[v34];
v37 = (v36).length;
((*out).buffer[v34]).buffer = initArray(((*out).buffer[v34]).buffer, ((*out).buffer[v34]).length, sizeof(int32_t), v37);
((*out).buffer[v34]).length = v37;
for (uint32_t v40 = 0; v40 < v37; v40 += 1)
{
((*out).buffer[v34]).buffer[v40] = ((v33).buffer[v34]).buffer[v40];
((*out).buffer[v34]).buffer[v40] = (v36).buffer[v40];
}
}
freeArray_awl_signedS32((v33).buffer, (v33).length);
freeArray((v16).buffer);
freeArray((st0).buffer);
freeArray((v36).buffer);
}
75 changes: 36 additions & 39 deletions tests/gold/tfModel.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,44 +77,41 @@
EparFor {e1xf32 | [*,*] :> [[*,*]]}
v174
\ v184 : 1xu32 ->
let
v194 : 1xu32 = ShiftLU {1xu32 | [*,*]} v184 2
in
EparFor {e1xf32 | [*,*] :> [[*,*]]}
4
\ v30195 : 1xu32 ->
EWrite {e1xf32 | [*,*] :> [[*,*]]}
(Add {1xu32 | [*,*]} v194 v30195)
EparFor {e1xf32 | [*,*] :> [[*,*]]}
4
\ v30195 : 1xu32 ->
EWrite {e1xf32 | [*,*] :> [[*,*]]}
(Add {1xu32 | [*,*]} (ShiftLU {1xu32 | [*,*]} v184 2) v30195)
(Condition {1xf32 | [*,*]}
(LTH {1xbool | [*,*]} v30195 3)
(Condition {1xf32 | [*,*]}
(LTH {1xbool | [*,*]} v30195 3)
(LTH {1xbool | [*,*]} v30195 2)
(Condition {1xf32 | [*,*]}
(LTH {1xbool | [*,*]} v30195 2)
(Condition {1xf32 | [*,*]}
(LTH {1xbool | [*,*]} v30195 1)
(GetIx {1xf32 | [*,*]}
v1
(Quot {1xu32 | [*,*]} (Add {1xu32 | [*,*]} v184 v30195) 1))
(GetIx {1xf32 | [*,*]}
v2
(Quot {1xu32 | [*,*]}
(Add {1xu32 | [*,*]}
v184
(GetIx {1xu32 | [*,*]} a10098 v30195))
1)))
(LTH {1xbool | [*,*]} v30195 1)
(GetIx {1xf32 | [*,*]}
v3
v1
(Quot {1xu32 | [*,*]} (Add {1xu32 | [*,*]} v184 v30195) 1))
(GetIx {1xf32 | [*,*]}
v2
(Quot {1xu32 | [*,*]}
(Add {1xu32 | [*,*]}
v184
(GetIx {1xu32 | [*,*]} a10108 v30195))
(GetIx {1xu32 | [*,*]} a10098 v30195))
1)))
(GetIx {1xf32 | [*,*]}
v4
v3
(Quot {1xu32 | [*,*]}
(Add {1xu32 | [*,*]}
v184
(GetIx {1xu32 | [*,*]} a10118 v30195))
1))))
(GetIx {1xu32 | [*,*]} a10108 v30195))
1)))
(GetIx {1xf32 | [*,*]}
v4
(Quot {1xu32 | [*,*]}
(Add {1xu32 | [*,*]}
v184
(GetIx {1xu32 | [*,*]} a10118 v30195))
1))))
v26 : a[30:30]1xf32 =
[0.9678828,-7.429327e-2,-1.4612858,-0.46043652,-8.7886825e-2,-0.26906747,-8.205062e-2,0.13083398,-0.107215166,-1.136796,-0.10522478,0.24648522,0.36117786,0.20214033,-0.1043409,-0.5552198,-0.59523,0.29417604,6.126838e-2,0.32809013,-0.2514663,0.12271329,0.13569178,-1.3657666,5.9321046e-2,-0.5515968,0.6249279,0.47829178,0.44494593,-0.50133526]
v21 : a[10:10]1xf32 =
Expand Down Expand Up @@ -254,7 +251,6 @@
\ v305 : 1xu32 ->
let
v308 : 1xu32 = Mul {1xu32 | [*,*]} v305 3
v307 : 1xu32 = ShiftLU {1xu32 | [*,*]} v305 2
a10218 : a[10:10]1xf32 =
EMaterialize {a[10:10]1xf32 | [*,10] :> [[*,*]]}
10
Expand Down Expand Up @@ -282,7 +278,9 @@
Add {1xf32 | [*,*]}
v322
(Mul {1xf32 | [*,*]}
(GetIx {1xf32 | [*,*]} v238 (Add {1xu32 | [*,*]} v307 v20321))
(GetIx {1xf32 | [*,*]}
v238
(Add {1xu32 | [*,*]} (ShiftLU {1xu32 | [*,*]} v305 2) v20321))
(GetIx {1xf32 | [*,*]}
a10186
(Add {1xu32 | [*,44]}
Expand Down Expand Up @@ -393,16 +391,15 @@
0
(\ v10386 : 1xu32 ->
\ v387 : 1xu32 ->
Condition {1xu32 | [*,*]}
(GTH {1xbool | [*,*]}
(GetIx {1xf32 | [*,*]}
v364
(Add {1xu32 | [*,*]}
v385
(GetIx {1xu32 | [1,*]} a10258 v10386)))
(GetIx {1xf32 | [*,*]} v364 (Add {1xu32 | [*,*]} v385 v387)))
(GetIx {1xu32 | [1,*]} a10258 v10386)
v387))))
let
v388 : 1xu32 = GetIx {1xu32 | [1,*]} a10258 v10386
in
Condition {1xu32 | [*,*]}
(GTH {1xbool | [*,*]}
(GetIx {1xf32 | [*,*]} v364 (Add {1xu32 | [*,*]} v385 v388))
(GetIx {1xf32 | [*,*]} v364 (Add {1xu32 | [*,*]} v385 v387)))
v388
v387))))
in
DivFrac {1xf32 | [*,*]}
(ForLoop {1xf32 | [*,*]}
Expand Down

0 comments on commit 6f03cdd

Please sign in to comment.