qbe

Internal scc patchset buffer for QBE
Log | Files | Refs | README | LICENSE

fold.c (11246B)


      1 #include "all.h"
      2 
      3 enum {
      4 	Bot = -2, /* lattice bottom */
      5 	Top = -1, /* lattice top */
      6 };
      7 
      8 typedef struct Edge Edge;
      9 
     10 struct Edge {
     11 	int dest;
     12 	int dead;
     13 	Edge *work;
     14 };
     15 
     16 static int *val;
     17 static Edge *flowrk, (*edge)[2];
     18 static Use **usewrk;
     19 static uint nuse;
     20 
     21 static int
     22 iscon(Con *c, int w, uint64_t k)
     23 {
     24 	if (c->type != CBits)
     25 		return 0;
     26 	if (w)
     27 		return (uint64_t)c->bits.i == k;
     28 	else
     29 		return (uint32_t)c->bits.i == (uint32_t)k;
     30 }
     31 
     32 static int
     33 latval(Ref r)
     34 {
     35 	switch (rtype(r)) {
     36 	case RTmp:
     37 		return val[r.val];
     38 	case RCon:
     39 		return r.val;
     40 	default:
     41 		die("unreachable");
     42 	}
     43 }
     44 
     45 static int
     46 latmerge(int v, int m)
     47 {
     48 	return m == Top ? v : (v == Top || v == m) ? m : Bot;
     49 }
     50 
     51 static void
     52 update(int t, int m, Fn *fn)
     53 {
     54 	Tmp *tmp;
     55 	uint u;
     56 
     57 	m = latmerge(val[t], m);
     58 	if (m != val[t]) {
     59 		tmp = &fn->tmp[t];
     60 		for (u=0; u<tmp->nuse; u++) {
     61 			vgrow(&usewrk, ++nuse);
     62 			usewrk[nuse-1] = &tmp->use[u];
     63 		}
     64 		val[t] = m;
     65 	}
     66 }
     67 
     68 static int
     69 deadedge(int s, int d)
     70 {
     71 	Edge *e;
     72 
     73 	e = edge[s];
     74 	if (e[0].dest == d && !e[0].dead)
     75 		return 0;
     76 	if (e[1].dest == d && !e[1].dead)
     77 		return 0;
     78 	return 1;
     79 }
     80 
     81 static void
     82 visitphi(Phi *p, int n, Fn *fn)
     83 {
     84 	int v;
     85 	uint a;
     86 
     87 	v = Top;
     88 	for (a=0; a<p->narg; a++)
     89 		if (!deadedge(p->blk[a]->id, n))
     90 			v = latmerge(v, latval(p->arg[a]));
     91 	update(p->to.val, v, fn);
     92 }
     93 
     94 static int opfold(int, int, Con *, Con *, Fn *);
     95 
     96 static void
     97 visitins(Ins *i, Fn *fn)
     98 {
     99 	int v, l, r;
    100 
    101 	if (rtype(i->to) != RTmp)
    102 		return;
    103 	if (optab[i->op].canfold) {
    104 		l = latval(i->arg[0]);
    105 		if (!req(i->arg[1], R))
    106 			r = latval(i->arg[1]);
    107 		else
    108 			r = CON_Z.val;
    109 		if (l == Bot || r == Bot)
    110 			v = Bot;
    111 		else if (l == Top || r == Top)
    112 			v = Top;
    113 		else
    114 			v = opfold(i->op, i->cls, &fn->con[l], &fn->con[r], fn);
    115 	} else
    116 		v = Bot;
    117 	/* fprintf(stderr, "\nvisiting %s (%p)", optab[i->op].name, (void *)i); */
    118 	update(i->to.val, v, fn);
    119 }
    120 
    121 static void
    122 visitjmp(Blk *b, int n, Fn *fn)
    123 {
    124 	int l;
    125 
    126 	switch (b->jmp.type) {
    127 	case Jjnz:
    128 		l = latval(b->jmp.arg);
    129 		assert(l != Top && "ssa invariant broken");
    130 		if (l == Bot) {
    131 			edge[n][1].work = flowrk;
    132 			edge[n][0].work = &edge[n][1];
    133 			flowrk = &edge[n][0];
    134 		}
    135 		else if (iscon(&fn->con[l], 0, 0)) {
    136 			assert(edge[n][0].dead);
    137 			edge[n][1].work = flowrk;
    138 			flowrk = &edge[n][1];
    139 		}
    140 		else {
    141 			assert(edge[n][1].dead);
    142 			edge[n][0].work = flowrk;
    143 			flowrk = &edge[n][0];
    144 		}
    145 		break;
    146 	case Jjmp:
    147 		edge[n][0].work = flowrk;
    148 		flowrk = &edge[n][0];
    149 		break;
    150 	default:
    151 		if (isret(b->jmp.type))
    152 			break;
    153 		die("unreachable");
    154 	}
    155 }
    156 
    157 static void
    158 initedge(Edge *e, Blk *s)
    159 {
    160 	if (s)
    161 		e->dest = s->id;
    162 	else
    163 		e->dest = -1;
    164 	e->dead = 1;
    165 	e->work = 0;
    166 }
    167 
    168 static int
    169 renref(Ref *r)
    170 {
    171 	int l;
    172 
    173 	if (rtype(*r) == RTmp)
    174 		if ((l=val[r->val]) != Bot) {
    175 			assert(l != Top && "ssa invariant broken");
    176 			*r = CON(l);
    177 			return 1;
    178 		}
    179 	return 0;
    180 }
    181 
    182 /* require rpo, use, pred */
    183 void
    184 fold(Fn *fn)
    185 {
    186 	Edge *e, start;
    187 	Use *u;
    188 	Blk *b, **pb;
    189 	Phi *p, **pp;
    190 	Ins *i;
    191 	int t, d;
    192 	uint n, a;
    193 
    194 	val = emalloc(fn->ntmp * sizeof val[0]);
    195 	edge = emalloc(fn->nblk * sizeof edge[0]);
    196 	usewrk = vnew(0, sizeof usewrk[0], Pheap);
    197 
    198 	for (t=0; t<fn->ntmp; t++)
    199 		val[t] = Top;
    200 	for (n=0; n<fn->nblk; n++) {
    201 		b = fn->rpo[n];
    202 		b->visit = 0;
    203 		initedge(&edge[n][0], b->s1);
    204 		initedge(&edge[n][1], b->s2);
    205 	}
    206 	initedge(&start, fn->start);
    207 	flowrk = &start;
    208 	nuse = 0;
    209 
    210 	/* 1. find out constants and dead cfg edges */
    211 	for (;;) {
    212 		e = flowrk;
    213 		if (e) {
    214 			flowrk = e->work;
    215 			e->work = 0;
    216 			if (e->dest == -1 || !e->dead)
    217 				continue;
    218 			e->dead = 0;
    219 			n = e->dest;
    220 			b = fn->rpo[n];
    221 			for (p=b->phi; p; p=p->link)
    222 				visitphi(p, n, fn);
    223 			if (b->visit == 0) {
    224 				for (i=b->ins; i<&b->ins[b->nins]; i++)
    225 					visitins(i, fn);
    226 				visitjmp(b, n, fn);
    227 			}
    228 			b->visit++;
    229 			assert(b->jmp.type != Jjmp
    230 				|| !edge[n][0].dead
    231 				|| flowrk == &edge[n][0]);
    232 		}
    233 		else if (nuse) {
    234 			u = usewrk[--nuse];
    235 			n = u->bid;
    236 			b = fn->rpo[n];
    237 			if (b->visit == 0)
    238 				continue;
    239 			switch (u->type) {
    240 			case UPhi:
    241 				visitphi(u->u.phi, u->bid, fn);
    242 				break;
    243 			case UIns:
    244 				visitins(u->u.ins, fn);
    245 				break;
    246 			case UJmp:
    247 				visitjmp(b, n, fn);
    248 				break;
    249 			default:
    250 				die("unreachable");
    251 			}
    252 		}
    253 		else
    254 			break;
    255 	}
    256 
    257 	if (debug['F']) {
    258 		fprintf(stderr, "\n> SCCP findings:");
    259 		for (t=Tmp0; t<fn->ntmp; t++) {
    260 			if (val[t] == Bot)
    261 				continue;
    262 			fprintf(stderr, "\n%10s: ", fn->tmp[t].name);
    263 			if (val[t] == Top)
    264 				fprintf(stderr, "Top");
    265 			else
    266 				printref(CON(val[t]), fn, stderr);
    267 		}
    268 		fprintf(stderr, "\n dead code: ");
    269 	}
    270 
    271 	/* 2. trim dead code, replace constants */
    272 	d = 0;
    273 	for (pb=&fn->start; (b=*pb);) {
    274 		if (b->visit == 0) {
    275 			d = 1;
    276 			if (debug['F'])
    277 				fprintf(stderr, "%s ", b->name);
    278 			edgedel(b, &b->s1);
    279 			edgedel(b, &b->s2);
    280 			*pb = b->link;
    281 			continue;
    282 		}
    283 		for (pp=&b->phi; (p=*pp);)
    284 			if (val[p->to.val] != Bot)
    285 				*pp = p->link;
    286 			else {
    287 				for (a=0; a<p->narg; a++)
    288 					if (!deadedge(p->blk[a]->id, b->id))
    289 						renref(&p->arg[a]);
    290 				pp = &p->link;
    291 			}
    292 		for (i=b->ins; i<&b->ins[b->nins]; i++)
    293 			if (renref(&i->to))
    294 				*i = (Ins){.op = Onop};
    295 			else
    296 				for (n=0; n<2; n++)
    297 					renref(&i->arg[n]);
    298 		renref(&b->jmp.arg);
    299 		if (b->jmp.type == Jjnz && rtype(b->jmp.arg) == RCon) {
    300 				if (iscon(&fn->con[b->jmp.arg.val], 0, 0)) {
    301 					edgedel(b, &b->s1);
    302 					b->s1 = b->s2;
    303 					b->s2 = 0;
    304 				} else
    305 					edgedel(b, &b->s2);
    306 				b->jmp.type = Jjmp;
    307 				b->jmp.arg = R;
    308 		}
    309 		pb = &b->link;
    310 	}
    311 
    312 	if (debug['F']) {
    313 		if (!d)
    314 			fprintf(stderr, "(none)");
    315 		fprintf(stderr, "\n\n> After constant folding:\n");
    316 		printfn(fn, stderr);
    317 	}
    318 
    319 	free(val);
    320 	free(edge);
    321 	vfree(usewrk);
    322 }
    323 
    324 /* boring folding code */
    325 
    326 static int
    327 foldint(Con *res, int op, int w, Con *cl, Con *cr)
    328 {
    329 	union {
    330 		int64_t s;
    331 		uint64_t u;
    332 		float fs;
    333 		double fd;
    334 	} l, r;
    335 	uint64_t x;
    336 	uint32_t lab;
    337 	int typ;
    338 
    339 	typ = CBits;
    340 	lab = 0;
    341 	l.s = cl->bits.i;
    342 	r.s = cr->bits.i;
    343 	if (op == Oadd) {
    344 		if (cl->type == CAddr) {
    345 			if (cr->type == CAddr)
    346 				return 1;
    347 			lab = cl->label;
    348 			typ = CAddr;
    349 		}
    350 		else if (cr->type == CAddr) {
    351 			lab = cr->label;
    352 			typ = CAddr;
    353 		}
    354 	}
    355 	else if (op == Osub) {
    356 		if (cl->type == CAddr) {
    357 			if (cr->type != CAddr) {
    358 				lab = cl->label;
    359 				typ = CAddr;
    360 			} else if (cl->label != cr->label)
    361 				return 1;
    362 		}
    363 		else if (cr->type == CAddr)
    364 			return 1;
    365 	}
    366 	else if (cl->type == CAddr || cr->type == CAddr)
    367 		return 1;
    368 	if (op == Odiv || op == Orem || op == Oudiv || op == Ourem) {
    369 		if (iscon(cr, w, 0))
    370 			return 1;
    371 		if (op == Odiv || op == Orem) {
    372 			x = w ? INT64_MIN : INT32_MIN;
    373 			if (iscon(cr, w, -1))
    374 			if (iscon(cl, w, x))
    375 				return 1;
    376 		}
    377 	}
    378 	switch (op) {
    379 	case Oadd:  x = l.u + r.u; break;
    380 	case Osub:  x = l.u - r.u; break;
    381 	case Oneg:  x = -l.u; break;
    382 	case Odiv:  x = w ? l.s / r.s : (int32_t)l.s / (int32_t)r.s; break;
    383 	case Orem:  x = w ? l.s % r.s : (int32_t)l.s % (int32_t)r.s; break;
    384 	case Oudiv: x = w ? l.u / r.u : (uint32_t)l.u / (uint32_t)r.u; break;
    385 	case Ourem: x = w ? l.u % r.u : (uint32_t)l.u % (uint32_t)r.u; break;
    386 	case Omul:  x = l.u * r.u; break;
    387 	case Oand:  x = l.u & r.u; break;
    388 	case Oor:   x = l.u | r.u; break;
    389 	case Oxor:  x = l.u ^ r.u; break;
    390 	case Osar:  x = (w ? l.s : (int32_t)l.s) >> (r.u & (31|w<<5)); break;
    391 	case Oshr:  x = (w ? l.u : (uint32_t)l.u) >> (r.u & (31|w<<5)); break;
    392 	case Oshl:  x = l.u << (r.u & (31|w<<5)); break;
    393 	case Oextsb: x = (int8_t)l.u;   break;
    394 	case Oextub: x = (uint8_t)l.u;  break;
    395 	case Oextsh: x = (int16_t)l.u;  break;
    396 	case Oextuh: x = (uint16_t)l.u; break;
    397 	case Oextsw: x = (int32_t)l.u;  break;
    398 	case Oextuw: x = (uint32_t)l.u; break;
    399 	case Ostosi: x = w ? (int64_t)cl->bits.s : (int32_t)cl->bits.s; break;
    400 	case Ostoui: x = w ? (uint64_t)cl->bits.s : (uint32_t)cl->bits.s; break;
    401 	case Odtosi: x = w ? (int64_t)cl->bits.d : (int32_t)cl->bits.d; break;
    402 	case Odtoui: x = w ? (uint64_t)cl->bits.d : (uint32_t)cl->bits.d; break;
    403 	case Ocast:
    404 		x = l.u;
    405 		if (cl->type == CAddr) {
    406 			lab = cl->label;
    407 			typ = CAddr;
    408 		}
    409 		break;
    410 	default:
    411 		if (Ocmpw <= op && op <= Ocmpl1) {
    412 			if (op <= Ocmpw1) {
    413 				l.u = (int32_t)l.u;
    414 				r.u = (int32_t)r.u;
    415 			} else
    416 				op -= Ocmpl - Ocmpw;
    417 			switch (op - Ocmpw) {
    418 			case Ciule: x = l.u <= r.u; break;
    419 			case Ciult: x = l.u < r.u;  break;
    420 			case Cisle: x = l.s <= r.s; break;
    421 			case Cislt: x = l.s < r.s;  break;
    422 			case Cisgt: x = l.s > r.s;  break;
    423 			case Cisge: x = l.s >= r.s; break;
    424 			case Ciugt: x = l.u > r.u;  break;
    425 			case Ciuge: x = l.u >= r.u; break;
    426 			case Cieq:  x = l.u == r.u; break;
    427 			case Cine:  x = l.u != r.u; break;
    428 			default: die("unreachable");
    429 			}
    430 		}
    431 		else if (Ocmps <= op && op <= Ocmps1) {
    432 			switch (op - Ocmps) {
    433 			case Cfle: x = l.fs <= r.fs; break;
    434 			case Cflt: x = l.fs < r.fs;  break;
    435 			case Cfgt: x = l.fs > r.fs;  break;
    436 			case Cfge: x = l.fs >= r.fs; break;
    437 			case Cfne: x = l.fs != r.fs; break;
    438 			case Cfeq: x = l.fs == r.fs; break;
    439 			case Cfo: x = l.fs < r.fs || l.fs >= r.fs; break;
    440 			case Cfuo: x = !(l.fs < r.fs || l.fs >= r.fs); break;
    441 			default: die("unreachable");
    442 			}
    443 		}
    444 		else if (Ocmpd <= op && op <= Ocmpd1) {
    445 			switch (op - Ocmpd) {
    446 			case Cfle: x = l.fd <= r.fd; break;
    447 			case Cflt: x = l.fd < r.fd;  break;
    448 			case Cfgt: x = l.fd > r.fd;  break;
    449 			case Cfge: x = l.fd >= r.fd; break;
    450 			case Cfne: x = l.fd != r.fd; break;
    451 			case Cfeq: x = l.fd == r.fd; break;
    452 			case Cfo: x = l.fd < r.fd || l.fd >= r.fd; break;
    453 			case Cfuo: x = !(l.fd < r.fd || l.fd >= r.fd); break;
    454 			default: die("unreachable");
    455 			}
    456 		}
    457 		else
    458 			die("unreachable");
    459 	}
    460 	*res = (Con){.type=typ, .label=lab, .bits={.i=x}};
    461 	return 0;
    462 }
    463 
    464 static void
    465 foldflt(Con *res, int op, int w, Con *cl, Con *cr)
    466 {
    467 	float xs, ls, rs;
    468 	double xd, ld, rd;
    469 
    470 	if (cl->type != CBits || cr->type != CBits)
    471 		err("invalid address operand for '%s'", optab[op].name);
    472 	*res = (Con){.type = CBits};
    473 	memset(&res->bits, 0, sizeof(res->bits));
    474 	if (w)  {
    475 		ld = cl->bits.d;
    476 		rd = cr->bits.d;
    477 		switch (op) {
    478 		case Oadd: xd = ld + rd; break;
    479 		case Osub: xd = ld - rd; break;
    480 		case Oneg: xd = -ld; break;
    481 		case Odiv: xd = ld / rd; break;
    482 		case Omul: xd = ld * rd; break;
    483 		case Oswtof: xd = (int32_t)cl->bits.i; break;
    484 		case Ouwtof: xd = (uint32_t)cl->bits.i; break;
    485 		case Osltof: xd = (int64_t)cl->bits.i; break;
    486 		case Oultof: xd = (uint64_t)cl->bits.i; break;
    487 		case Oexts: xd = cl->bits.s; break;
    488 		case Ocast: xd = ld; break;
    489 		default: die("unreachable");
    490 		}
    491 		res->bits.d = xd;
    492 		res->flt = 2;
    493 	} else {
    494 		ls = cl->bits.s;
    495 		rs = cr->bits.s;
    496 		switch (op) {
    497 		case Oadd: xs = ls + rs; break;
    498 		case Osub: xs = ls - rs; break;
    499 		case Oneg: xs = -ls; break;
    500 		case Odiv: xs = ls / rs; break;
    501 		case Omul: xs = ls * rs; break;
    502 		case Oswtof: xs = (int32_t)cl->bits.i; break;
    503 		case Ouwtof: xs = (uint32_t)cl->bits.i; break;
    504 		case Osltof: xs = (int64_t)cl->bits.i; break;
    505 		case Oultof: xs = (uint64_t)cl->bits.i; break;
    506 		case Otruncd: xs = cl->bits.d; break;
    507 		case Ocast: xs = ls; break;
    508 		default: die("unreachable");
    509 		}
    510 		res->bits.s = xs;
    511 		res->flt = 1;
    512 	}
    513 }
    514 
    515 static int
    516 opfold(int op, int cls, Con *cl, Con *cr, Fn *fn)
    517 {
    518 	Ref r;
    519 	Con c;
    520 
    521 	if (cls == Kw || cls == Kl) {
    522 		if (foldint(&c, op, cls == Kl, cl, cr))
    523 			return Bot;
    524 	} else
    525 		foldflt(&c, op, cls == Kd, cl, cr);
    526 	r = newcon(&c, fn);
    527 	assert(!(cls == Ks || cls == Kd) || c.flt);
    528 	return r.val;
    529 }