commit 3f147ed2e078769a71b2935fc36cb08b2b0ddb67
parent 8fdea1dd5236f2693b677fc6bd6e2bb417c0fccd
Author: Quentin Carbonneaux <quentin.carbonneaux@yale.edu>
Date: Sun, 4 Sep 2016 20:22:38 -0400
implement a simple load elimination pass
Diffstat:
M | Makefile | | | 2 | +- |
M | all.h | | | 6 | +++++- |
A | load.c | | | 408 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
3 files changed, 414 insertions(+), 2 deletions(-)
diff --git a/Makefile b/Makefile
@@ -4,7 +4,7 @@ ABI = sysv
V = @
OBJDIR = obj
-SRC = main.c util.c parse.c cfg.c mem.c ssa.c alias.c copy.c fold.c live.c $(ABI).c isel.c spill.c rega.c emit.c
+SRC = main.c util.c parse.c cfg.c mem.c ssa.c alias.c load.c copy.c fold.c live.c $(ABI).c isel.c spill.c rega.c emit.c
OBJ = $(SRC:%.c=$(OBJDIR)/%.o)
CFLAGS += -Wall -Wextra -std=c99 -g -pedantic
diff --git a/all.h b/all.h
@@ -530,7 +530,8 @@ void bsdiff(BSet *, BSet *);
int bsequal(BSet *, BSet *);
int bsiter(BSet *, int *);
-static inline int bshas(BSet *bs, uint elt)
+static inline int
+bshas(BSet *bs, uint elt)
{
assert(elt < bs->nt * NBit);
return (bs->t[elt/NBit] & BIT(elt%NBit)) != 0;
@@ -561,6 +562,9 @@ void fillalias(Fn *);
int alias(Ref, int, Ref, int, int *, Fn *);
int escapes(Ref, Fn *);
+/* load.c */
+void loadopt(Fn *);
+
/* ssa.c */
void filluse(Fn *);
void fillpreds(Fn *);
diff --git a/load.c b/load.c
@@ -0,0 +1,408 @@
+#include "all.h"
+
+/* bit mask covering the low w bytes of a value;
+ * it is written BIT(8w-1)*2-1 rather than BIT(8w)-1
+ * so that w==8 does not require a shift by 64
+ * (NOTE(review): assumes BIT(n) is a 64-bit 1<<n,
+ * its definition is not in this file) */
+#define MASK(w) (BIT(8*(w)-1)*2-1) /* must work when w==8 */
+
+/* short names for the structures local to this pass */
+typedef struct Loc Loc;
+typedef struct Slice Slice;
+typedef struct Insert Insert;
+
+
+/* a program point where new instructions can be logged
+ * for insertion, together with what may be inserted;
+ * keep the field order: loadopt() fills a Loc with a
+ * positional initializer (type, off, blk) */
+struct Loc {
+ enum {
+ LRoot, /* right above the original load */
+ LLoad, /* inserting a load is allowed */
+ LNoLoad, /* only scalar operations allowed */
+ } type;
+ uint off; /* index in blk->ins before which to insert */
+ Blk *blk; /* block receiving the insertions */
+};
+
+/* the chunk of memory whose contents def() looks for;
+ * keep the field order: loadopt() fills a Slice with a
+ * positional initializer (ref, sz, cls) */
+struct Slice {
+ Ref ref; /* address of the chunk (first argument of the load) */
+ short sz; /* size in bytes, as returned by loadsz() */
+ short cls; /* load class */
+};
+
+/* one entry of the insertion log: either a new
+ * instruction or a new phi for the block with id bid */
+struct Insert {
+ uint isphi:1; /* 1: new.phi is valid, 0: new.ins is valid */
+ uint num:31; /* creation number set by iins(), last sort key of icmp() */
+ int bid; /* id of the block to modify */
+ uint off; /* instruction index to insert before (not set for phis) */
+ union {
+ Ins ins; /* the instruction to insert */
+ struct {
+ Slice m; /* slice whose contents the phi holds */
+ Phi *p; /* the phi to link into the block */
+ } phi;
+ } new;
+};
+
+/* state of the pass, (re)initialized by loadopt() */
+static Fn *curf; /* function being processed */
+static uint inum; /* current insertion number */
+static Insert *ilog; /* global insertion log */
+static uint nlog; /* number of entries in the log */
+
+/* number of bytes read from memory by the load
+ * instruction l; dies when l is not a load */
+static int
+loadsz(Ins *l)
+{
+	int op;
+
+	op = l->op;
+	if (op == Oloadsb || op == Oloadub)
+		return 1;
+	if (op == Oloadsh || op == Oloaduh)
+		return 2;
+	if (op == Oloadsw || op == Oloaduw)
+		return 4;
+	if (op == Oload)
+		/* the width of a plain load follows its class */
+		return KWIDE(l->cls) ? 8 : 4;
+	die("unreachable");
+}
+
+/* number of bytes written to memory by the store
+ * instruction s; dies when s is not a store */
+static int
+storesz(Ins *s)
+{
+	int op;
+
+	op = s->op;
+	if (op == Ostoreb)
+		return 1;
+	if (op == Ostoreh)
+		return 2;
+	if (op == Ostorew || op == Ostores)
+		return 4;
+	if (op == Ostorel || op == Ostored)
+		return 8;
+	die("unreachable");
+}
+
+/* appends the instruction 'op a0, a1' of class cls to
+ * the insertion log, to be placed at location l;
+ * the result goes into a fresh temporary of class cls
+ * which is returned */
+static Ref
+iins(int cls, int op, Ref a0, Ref a1, Loc *l)
+{
+	Insert *slot;
+	Ref to;
+
+	nlog++;
+	vgrow(&ilog, nlog);
+	slot = &ilog[nlog-1];
+	slot->isphi = 0;
+	slot->num = inum++; /* keeps creation order within one location */
+	slot->bid = l->blk->id;
+	slot->off = l->off;
+	slot->new.ins = (Ins){op, R, {a0, a1}, cls};
+	to = newtmp("ld", cls, curf);
+	slot->new.ins.to = to;
+	return to;
+}
+
+/* makes the value in *r usable as an operand of class
+ * cls, logging conversion instructions at l if needed;
+ * constants are left untouched, and so are temporaries
+ * that already have class cls, or class Kl when Kw is
+ * requested */
+static void
+cast(Ref *r, int cls, Loc *l)
+{
+	int from;
+	Ref v;
+
+	v = *r;
+	if (rtype(v) == RCon)
+		return;
+	assert(rtype(v) == RTmp);
+	from = curf->tmp[v.val].cls;
+	if (from == cls)
+		return;
+	if (cls == Kw && from == Kl)
+		return;
+	/* a wide value is never converted to a narrow class here */
+	assert(!KWIDE(from) || KWIDE(cls));
+	if (KWIDE(cls) == KWIDE(from)) {
+		/* same width, different class: one cast is enough */
+		*r = iins(cls, Ocast, v, R, l);
+		return;
+	}
+	/* narrow to wide, the only supported target is Kl;
+	 * an Ks value first goes through Kw */
+	assert(cls == Kl);
+	if (from == Ks)
+		v = iins(Kw, Ocast, v, R, l);
+	*r = iins(Kl, Oextuw, v, R, l);
+}
+
+/* replaces *r by (*r & msk) computed in class cls;
+ * the needed instructions are logged at location l */
+static inline void
+mask(int cls, Ref *r, bits msk, Loc *l)
+{
+	Ref k;
+
+	cast(r, cls, l);
+	k = getcon(msk, curf);
+	*r = iins(cls, Oand, *r, k, l);
+}
+
+/* logs at location l a load of the slice sl and returns
+ * the loaded value with the bits outside msk cleared;
+ * when msk covers the whole slice the value has class
+ * sl.cls, otherwise it has an integer class */
+static Ref
+load(Slice sl, bits msk, Loc *l)
+{
+	/* unsigned load opcode indexed by the slice size */
+	static const int ldop[] = {
+		[1] = Oloadub,
+		[2] = Oloaduh,
+		[4] = Oloaduw,
+		[8] = Oload,
+	};
+	Ref r;
+	int cls;
+
+	if (msk == MASK(sl.sz))
+		/* all the bits are wanted, no masking needed */
+		return iins(sl.cls, ldop[sl.sz], sl.ref, R, l);
+	cls = sl.sz > 4 ? Kl : Kw;
+	r = iins(cls, ldop[sl.sz], sl.ref, R, l);
+	mask(cls, &r, msk, l);
+	return r;
+}
+
+/* returns a ref containing the contents of the slice
+ * passed as argument, all the bits set to 0 in the
+ * mask argument are zeroed in the result;
+ * the returned ref has an integer class when the
+ * mask does not cover all the bits of the slice,
+ * otherwise, it has class sl.cls
+ * the procedure returns R when it fails
+ *
+ * the search goes backwards in block b, starting right
+ * above instruction i (or at the end of b when i is
+ * null), and continues in the predecessors of b;
+ * il is the location where the instructions computing
+ * the result are logged for insertion */
+static Ref
+def(Slice sl, bits msk, Blk *b, Ins *i, Loc *il)
+{
+ Blk *bp;
+ bits msk1, msks;
+ int off, cls, cls1, op, sz, ld;
+ uint np, oldl, oldt;
+ Ref r, r1;
+ Phi *p;
+ Insert *ist;
+ Loc l;
+
+ /* invariants:
+ * -1- b dominates il->blk; so we can use
+ * temporaries of b in il->blk
+ * -2- if il->type != LNoLoad, then il->blk
+ * postdominates the original load; so it
+ * is safe to load in il->blk
+ * -3- if il->type != LNoLoad, then b
+ * postdominates il->blk (and by 2, the
+ * original load)
+ */
+ assert(dom(b, il->blk));
+ /* remember the log and temporary counts so that
+ * everything created by this call can be undone */
+ oldl = nlog;
+ oldt = curf->ntmp;
+ /* failure path, only entered with 'goto Load':
+ * roll back, then emit a plain load if the location
+ * allows it (LLoad), otherwise report failure */
+ if (0) {
+ Load:
+ curf->ntmp = oldt;
+ nlog = oldl;
+ if (il->type != LLoad)
+ return R;
+ return load(sl, msk, il);
+ }
+
+ if (!i)
+ i = &b->ins[b->nins];
+ cls = sl.sz > 4 ? Kl : Kw; /* integer class able to hold the slice */
+ msks = MASK(sl.sz); /* mask of all the bits of the slice */
+
+ /* scan the instructions of b above i, last to first */
+ while (i > b->ins) {
+ --i;
+ /* give up when the address itself is defined here,
+ * or on a call if the address escapes */
+ if (req(i->to, sl.ref)
+ || (i->op == Ocall && escapes(sl.ref, curf)))
+ goto Load;
+ /* r1 is the address accessed by i, r the value
+ * loaded from or stored to it, sz the access size */
+ ld = isload(i->op);
+ if (ld) {
+ sz = loadsz(i);
+ r1 = i->arg[0];
+ r = i->to;
+ } else if (isstore(i->op)) {
+ sz = storesz(i);
+ r1 = i->arg[1];
+ r = i->arg[0];
+ } else
+ continue;
+ switch (alias(sl.ref, sl.sz, r1, sz, &off, curf)) {
+ case MustAlias:
+ /* msk1 is the set of slice bits provided by r once
+ * it is shifted by |off| bytes in direction op */
+ if (off < 0) {
+ off = -off;
+ msk1 = (MASK(sz) << 8*off) & msks;
+ op = Oshl;
+ } else {
+ msk1 = (MASK(sz) >> 8*off) & msks;
+ op = Oshr;
+ }
+ /* none of the wanted bits are here, keep scanning */
+ if ((msk1 & msk) == 0)
+ break;
+ if (off) {
+ /* align the bits of r with the ones of the slice */
+ cls1 = cls;
+ if (op == Oshr && off + sl.sz > 4)
+ cls1 = Kl;
+ cast(&r, cls1, il);
+ r1 = getcon(8*off, curf);
+ r = iins(cls1, op, r, r1, il);
+ }
+ /* clear the bits of r that are unwanted or that
+ * lie outside of what this access provides */
+ if ((msk1 & msk) != msk1 || off + sz < sl.sz)
+ mask(cls, &r, msk1 & msk, il);
+ if ((msk & ~msk1) != 0) {
+ /* some wanted bits are still missing, look for
+ * them further up and merge them with r */
+ r1 = def(sl, msk & ~msk1, b, i, il);
+ if (req(r1, R))
+ goto Load;
+ r = iins(cls, Oor, r, r1, il);
+ }
+ if (msk == msks)
+ cast(&r, sl.cls, il);
+ return r;
+ case MayAlias:
+ /* a load does not change memory so it can be
+ * skipped, a store forces us to give up */
+ if (ld)
+ break;
+ else
+ goto Load;
+ case NoAlias:
+ break;
+ default:
+ die("unreachable");
+ }
+ }
+
+ /* reuse a phi already logged for the same slice
+ * (same address and size) in this block */
+ for (ist=ilog; ist<&ilog[nlog]; ++ist)
+ if (ist->isphi && ist->bid == b->id)
+ if (req(ist->new.phi.m.ref, sl.ref))
+ if (ist->new.phi.m.sz == sl.sz) {
+ r = ist->new.phi.p->to;
+ if (msk != msks)
+ mask(cls, &r, msk, il);
+ else
+ cast(&r, sl.cls, il);
+ return r;
+ }
+
+ for (p=b->phi; p; p=p->link)
+ if (req(p->to, sl.ref))
+ /* scanning predecessors in that
+ * case would be unsafe */
+ goto Load;
+
+ if (b->npred == 0)
+ goto Load;
+ if (b->npred == 1) {
+ /* single predecessor: continue the search there and
+ * keep inserting at il; loads are disallowed when
+ * the predecessor has a second successor */
+ bp = b->pred[0];
+ assert(bp->loop == il->blk->loop);
+ l = *il;
+ if (bp->s2)
+ l.type = LNoLoad;
+ r1 = def(sl, msk, bp, 0, &l);
+ if (req(r1, R))
+ goto Load;
+ return r1;
+ }
+
+ /* several predecessors: log a new phi for b that merges
+ * the full contents of the slice found in each of them */
+ r = newtmp("ld", sl.cls, curf);
+ p = alloc(sizeof *p);
+ vgrow(&ilog, ++nlog);
+ ist = &ilog[nlog-1];
+ ist->isphi = 1;
+ ist->bid = b->id;
+ ist->new.phi.m = sl;
+ ist->new.phi.p = p;
+ p->to = r;
+ p->cls = sl.cls;
+ p->narg = b->npred;
+ for (np=0; np<b->npred; ++np) {
+ bp = b->pred[np];
+ /* a load may be inserted in the predecessor only if
+ * it has a single successor, loads are allowed at
+ * il, and its loop field is lower than the one of
+ * il->blk */
+ if (!bp->s2
+ && il->type != LNoLoad
+ && bp->loop < il->blk->loop)
+ l.type = LLoad;
+ else
+ l.type = LNoLoad;
+ /* insertions for this argument go at the end of bp */
+ l.blk = bp;
+ l.off = bp->nins;
+ r1 = def(sl, msks, bp, 0, &l);
+ if (req(r1, R))
+ goto Load;
+ p->arg[np] = r1;
+ p->blk[np] = bp;
+ }
+ if (msk != msks)
+ mask(cls, &r, msk, il);
+ return r;
+}
+
+/* qsort() comparison function for the insertion log;
+ * entries are ordered by block id, then phis go before
+ * instructions, then instructions are ordered by
+ * insertion offset and finally by creation number;
+ * two phis of the same block compare equal */
+static int
+icmp(const void *pa, const void *pb)
+{
+	const Insert *a, *b;
+
+	a = pa;
+	b = pb;
+	if (a->bid != b->bid)
+		return a->bid < b->bid ? -1 : +1;
+	if (a->isphi || b->isphi)
+		/* -1 if only a is a phi, +1 if only b is, 0 if both are */
+		return (int)b->isphi - (int)a->isphi;
+	/* explicit comparisons are used because the offsets
+	 * are unsigned: their difference converted to int
+	 * does not always have the sign of the comparison */
+	if (a->off != b->off)
+		return a->off < b->off ? -1 : +1;
+	if (a->num != b->num)
+		return a->num < b->num ? -1 : +1;
+	return 0;
+}
+
+/* load elimination pass over the function fn;
+ * for each load, def() is asked to compute the loaded
+ * value from earlier instructions; when it succeeds the
+ * load is turned into a copy or an extension of that
+ * value and the logged instructions and phis are
+ * spliced into their blocks
+ *
+ * require rpo ssa alias */
+void
+loadopt(Fn *fn)
+{
+ Ins *i, *ib;
+ Blk *b;
+ int n, sz;
+ uint ni, ext, nt;
+ Insert *ist;
+ Slice sl;
+ Loc l;
+
+ curf = fn;
+ ilog = vnew(0, sizeof ilog[0], emalloc);
+ nlog = 0;
+ inum = 0;
+ /* first phase: analyze every load; the replacement
+ * value (or R on failure) is stashed in the second
+ * argument of the load */
+ for (b=fn->start; b; b=b->link)
+ for (i=b->ins; i<&b->ins[b->nins]; ++i) {
+ if (!isload(i->op))
+ continue;
+ sz = loadsz(i);
+ sl = (Slice){i->arg[0], sz, i->cls};
+ l = (Loc){LRoot, i-b->ins, b};
+ i->arg[1] = def(sl, MASK(sz), b, i, &l);
+ }
+ /* second phase: sort the log (see icmp) so that it can
+ * be consumed sequentially while walking the blocks
+ * by increasing id */
+ qsort(ilog, nlog, sizeof ilog[0], icmp);
+ vgrow(&ilog, nlog+1);
+ ilog[nlog].bid = fn->nblk; /* add a sentinel */
+ ib = vnew(0, sizeof(Ins), emalloc);
+ for (ist=ilog, n=0; n<fn->nblk; ++n) {
+ b = fn->rpo[n];
+ /* phis come first in the log, link them into the block */
+ for (; ist->bid == n && ist->isphi; ++ist) {
+ ist->new.phi.p->link = b->phi;
+ b->phi = ist->new.phi.p;
+ }
+ /* rebuild the instructions of b in ib: ni counts the
+ * original instructions consumed, nt the ones emitted */
+ ni = 0;
+ nt = 0;
+ for (;;) {
+ /* logged instructions for offset ni go before
+ * the original instruction at index ni */
+ if (ist->bid == n && ist->off == ni)
+ i = &ist++->new.ins;
+ else {
+ if (ni == b->nins)
+ break;
+ i = &b->ins[ni++];
+ if (isload(i->op)
+ && !req(i->arg[1], R)) {
+ /* the load has a replacement value;
+ * NOTE(review): ext assumes the Oext*
+ * opcodes follow the same order as the
+ * Oload* ones starting at Oloadsb */
+ ext = Oextsb + i->op - Oloadsb;
+ switch (i->op) {
+ default:
+ die("unreachable");
+ case Oloadsb:
+ case Oloadub:
+ case Oloadsh:
+ case Oloaduh:
+ i->op = ext;
+ break;
+ case Oloadsw:
+ case Oloaduw:
+ /* a word needs an extension only when
+ * the result is a long */
+ if (i->cls == Kl) {
+ i->op = ext;
+ break;
+ }
+ /* fall through */
+ case Oload:
+ i->op = Ocopy;
+ break;
+ }
+ i->arg[0] = i->arg[1];
+ i->arg[1] = R;
+ }
+ }
+ vgrow(&ib, ++nt);
+ ib[nt-1] = *i;
+ }
+ b->nins = nt;
+ idup(&b->ins, ib, nt);
+ }
+ vfree(ib);
+ vfree(ilog);
+ if (debug['M']) {
+ fprintf(stderr, "\n> After load elimination:\n");
+ printfn(fn, stderr);
+ }
+}