commit ebcc123e4dc0d5497e816fe824c1849685e295af
parent 45f3493777488b05c28746670f585c8e41a76681
Author: Quentin Carbonneaux <quentin.carbonneaux@yale.edu>
Date: Thu, 7 Apr 2016 13:08:31 -0400
add boring folding code
Diffstat:
| M | fold.c | | | 220 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--- |
1 file changed, 212 insertions(+), 8 deletions(-)
diff --git a/fold.c b/fold.c
@@ -13,8 +13,6 @@ struct Edge {
Edge *work;
};
-int evalop(int, int, int, int);
-
static int *val;
static Edge *flowrk, (*edge)[2];
static Use **usewrk;
@@ -36,6 +34,8 @@ latval(Ref r)
return r.val;
case RType:
return Bot;
+ case -1:
+ return CON_Z.val;
default:
die("unreachable");
}
@@ -64,7 +64,7 @@ update(int t, int v, Fn *fn)
if (val[t] != v) {
tmp = &fn->tmp[t];
for (u=0; u<tmp->nuse; u++) {
- vgrow(usewrk, ++nuse);
+ vgrow(&usewrk, ++nuse);
usewrk[nuse-1] = &tmp->use[u];
}
}
@@ -89,6 +89,8 @@ visitphi(Phi *p, int n, Fn *fn)
update(p->to.val, v, fn);
}
+static int opfold(int, int, Con *, Con *, Fn *);
+
static void
visitins(Ins *i, Fn *fn)
{
@@ -96,9 +98,17 @@ visitins(Ins *i, Fn *fn)
if (rtype(i->to) != RTmp)
return;
- l = latval(i->arg[0]);
- r = latval(i->arg[1]);
- v = evalop(i->op, i->cls, l, r);
+ if (opdesc[i->op].cfold) {
+ l = latval(i->arg[0]);
+ r = latval(i->arg[1]);
+ if (l == Bot || r == Bot)
+ v = Bot;
+ else if (l == Top || r == Top)
+ v = Top;
+ else
+ v = opfold(i->op, i->cls, &fn->con[l], &fn->con[r], fn);
+ } else
+ v = Bot;
assert(v != Top);
update(i->to.val, v, fn);
}
@@ -164,8 +174,17 @@ fold(Fn *fn)
edge = emalloc(fn->nblk * sizeof edge[0]);
usewrk = vnew(0, sizeof usewrk[0]);
- for (n=0; n<fn->ntmp; n++)
- val[n] = Bot;
+ for (b=fn->start; b; b=b->link) {
+ for (p=b->phi; p; p=p->link)
+ val[p->to.val] = Top;
+ for (i=b->ins; i-b->ins < b->nins; i++)
+ if (rtype(i->to) == RTmp) {
+ if (opdesc[i->op].cfold)
+ val[i->to.val] = Top;
+ else
+ val[i->to.val] = Bot;
+ }
+ }
for (n=0; n<fn->nblk; n++) {
b = fn->rpo[n];
b->visit = 0;
@@ -229,3 +248,188 @@ fold(Fn *fn)
free(val);
free(edge);
}
+
+/* boring folding code */
+
+static void
+foldint(Con *res, int op, int w, Con *cl, Con *cr)
+{
+ union {
+ int64_t s;
+ uint64_t u;
+ float fs;
+ double fd;
+ } l, r;
+ uint64_t x;
+ char *lab;
+
+ lab = 0;
+ l.s = cl->bits.i;
+ r.s = cl->bits.i;
+ switch (op) {
+ case OAdd:
+ x = l.u + r.u;
+ if (cl->type == CAddr) {
+ if (cr->type == CAddr)
+ err("undefined addition (addr + addr)");
+ lab = cl->label;
+ }
+ else if (cr->type == CAddr)
+ lab = cr->label;
+ break;
+ case OSub:
+ x = l.u - r.u;
+ if (cl->type == CAddr) {
+ if (cr->type != CAddr)
+ lab = cl->label;
+ else if (strcmp(cl->label, cr->label) != 0)
+ err("undefined substraction (addr1 - addr2)");
+ }
+ else if (cr->type == CAddr)
+ err("undefined substraction (num - addr)");
+ break;
+ case ODiv: x = l.s / r.s; break;
+ case ORem: x = l.s % r.s; break;
+ case OUDiv: x = l.u / r.u; break;
+ case OURem: x = l.u % r.u; break;
+ case OMul: x = l.u * r.u; break;
+ case OAnd: x = l.u & r.u; break;
+ case OOr: x = l.u | r.u; break;
+ case OXor: x = l.u ^ r.u; break;
+ case OSar: x = l.s >> (r.u & 63); break;
+ case OShr: x = l.u >> (r.u & 63); break;
+ case OShl: x = l.u << (r.u & 63); break;
+ case OExtsb: x = (int8_t)l.u; break;
+ case OExtub: x = (uint8_t)l.u; break;
+ case OExtsh: x = (int16_t)l.u; break;
+ case OExtuh: x = (uint16_t)l.u; break;
+ case OExtsw: x = (int32_t)l.u; break;
+ case OExtuw: x = (uint32_t)l.u; break;
+ case OFtosi:
+ if (w)
+ x = (int64_t)cl->bits.d;
+ else
+ x = (int32_t)cl->bits.s;
+ break;
+ case OCast:
+ x = l.u;
+ if (cl->type == CAddr)
+ lab = cl->label;
+ break;
+ default:
+ if (OCmpw <= op && op <= OCmpl1) {
+ if (op <= OCmpw1) {
+ l.u = (uint32_t)l.u;
+ r.u = (uint32_t)r.u;
+ } else
+ op -= OCmpl - OCmpw;
+ switch (op - OCmpw) {
+ case ICule: x = l.u <= r.u; break;
+ case ICult: x = l.u < r.u; break;
+ case ICsle: x = l.s <= r.s; break;
+ case ICslt: x = l.s < r.s; break;
+ case ICsgt: x = l.s > r.s; break;
+ case ICsge: x = l.s >= r.s; break;
+ case ICugt: x = l.u > r.u; break;
+ case ICuge: x = l.u >= r.u; break;
+ case ICeq: x = l.u == r.u; break;
+ case ICne: x = l.u != r.u; break;
+ default: die("unreachable");
+ }
+ }
+ else if (OCmps <= op && op <= OCmps1) {
+ switch (op - OCmps) {
+ case FCle: x = l.fs <= r.fs; break;
+ case FClt: x = l.fs < r.fs; break;
+ case FCgt: x = l.fs > r.fs; break;
+ case FCge: x = l.fs >= r.fs; break;
+ case FCne: x = l.fs != r.fs; break;
+ case FCeq: x = l.fs == r.fs; break;
+ case FCo: x = l.fs < r.fs || l.fs >= r.fs; break;
+ case FCuo: x = !(l.fs < r.fs || l.fs >= r.fs); break;
+ default: die("unreachable");
+ }
+ }
+ else if (OCmpd <= op && op <= OCmpd1) {
+ switch (op - OCmpd) {
+ case FCle: x = l.fd <= r.fd; break;
+ case FClt: x = l.fd < r.fd; break;
+ case FCgt: x = l.fd > r.fd; break;
+ case FCge: x = l.fd >= r.fd; break;
+ case FCne: x = l.fd != r.fd; break;
+ case FCeq: x = l.fd == r.fd; break;
+ case FCo: x = l.fd < r.fd || l.fd >= r.fd; break;
+ case FCuo: x = !(l.fd < r.fd || l.fd >= r.fd); break;
+ default: die("unreachable");
+ }
+ }
+ else
+ die("unreachable");
+ }
+ *res = (Con){lab ? CAddr : CBits, .bits={.i=x}};
+ if (lab)
+ strcpy(res->label, lab);
+}
+
+static void
+foldflt(Con *res, int op, int w, Con *cl, Con *cr)
+{
+ float xs, ls, rs;
+ double xd, ld, rd;
+
+ if (w) {
+ ld = cl->bits.d;
+ rd = cr->bits.d;
+ switch (op) {
+ case OAdd: xd = ld + rd; break;
+ case OSub: xd = ld - rd; break;
+ case ODiv: xd = ld / rd; break;
+ case OMul: xd = ld * rd; break;
+ case OSitof: xd = cl->bits.i; break;
+ case OExts: xd = cl->bits.s; break;
+ case OCast: xd = cl->bits.d; break;
+ default: die("unreachable");
+ }
+ *res = (Con){CBits, .bits={.d=xd}, .flt=2};
+ } else {
+ ls = cl->bits.s;
+ rs = cr->bits.s;
+ switch (op) {
+ case OAdd: xs = ls + rs; break;
+ case OSub: xs = ls - rs; break;
+ case ODiv: xs = ls / rs; break;
+ case OMul: xs = ls * rs; break;
+ case OSitof: xs = cl->bits.i; break;
+ case OTruncd: xs = cl->bits.d; break;
+ case OCast: xs = cl->bits.s; break;
+ default: die("unreachable");
+ }
+ *res = (Con){CBits, .bits={.s=xs}, .flt=1};
+ }
+}
+
+static int
+opfold(int op, int cls, Con *cl, Con *cr, Fn *fn)
+{
+ int nc;
+ Con c;
+
+ if ((op == ODiv || op == OUDiv
+ || op == ORem || op == OURem) && czero(cr))
+ err("null divisor in '%s'", opdesc[op].name);
+ if (cls == Kw || cls == Kl)
+ foldint(&c, op, cls == Kl, cl, cr);
+ else {
+ if (cl->type != CBits || cr->type != CBits)
+ err("invalid address operand for '%s'", opdesc[op].name);
+ foldflt(&c, op, cls == Kd, cl, cr);
+ }
+ if (c.type == CBits)
+ nc = getcon(c.bits.i, fn).val;
+ else {
+ nc = fn->ncon;
+ vgrow(&fn->con, ++fn->ncon);
+ }
+ fn->con[nc] = c;
+ return nc;
+}