fold.c (11332B)
1 #include "all.h" 2 3 enum { 4 Bot = -1, /* lattice bottom */ 5 Top = 0, /* lattice top (matches UNDEF) */ 6 }; 7 8 typedef struct Edge Edge; 9 10 struct Edge { 11 int dest; 12 int dead; 13 Edge *work; 14 }; 15 16 static int *val; 17 static Edge *flowrk, (*edge)[2]; 18 static Use **usewrk; 19 static uint nuse; 20 21 static int 22 iscon(Con *c, int w, uint64_t k) 23 { 24 if (c->type != CBits) 25 return 0; 26 if (w) 27 return (uint64_t)c->bits.i == k; 28 else 29 return (uint32_t)c->bits.i == (uint32_t)k; 30 } 31 32 static int 33 latval(Ref r) 34 { 35 switch (rtype(r)) { 36 case RTmp: 37 return val[r.val]; 38 case RCon: 39 return r.val; 40 default: 41 die("unreachable"); 42 } 43 } 44 45 static int 46 latmerge(int v, int m) 47 { 48 return m == Top ? v : (v == Top || v == m) ? m : Bot; 49 } 50 51 static void 52 update(int t, int m, Fn *fn) 53 { 54 Tmp *tmp; 55 uint u; 56 57 m = latmerge(val[t], m); 58 if (m != val[t]) { 59 tmp = &fn->tmp[t]; 60 for (u=0; u<tmp->nuse; u++) { 61 vgrow(&usewrk, ++nuse); 62 usewrk[nuse-1] = &tmp->use[u]; 63 } 64 val[t] = m; 65 } 66 } 67 68 static int 69 deadedge(int s, int d) 70 { 71 Edge *e; 72 73 e = edge[s]; 74 if (e[0].dest == d && !e[0].dead) 75 return 0; 76 if (e[1].dest == d && !e[1].dead) 77 return 0; 78 return 1; 79 } 80 81 static void 82 visitphi(Phi *p, int n, Fn *fn) 83 { 84 int v; 85 uint a; 86 87 v = Top; 88 for (a=0; a<p->narg; a++) 89 if (!deadedge(p->blk[a]->id, n)) 90 v = latmerge(v, latval(p->arg[a])); 91 update(p->to.val, v, fn); 92 } 93 94 static int opfold(int, int, Con *, Con *, Fn *); 95 96 static void 97 visitins(Ins *i, Fn *fn) 98 { 99 int v, l, r; 100 101 if (rtype(i->to) != RTmp) 102 return; 103 if (optab[i->op].canfold) { 104 l = latval(i->arg[0]); 105 if (!req(i->arg[1], R)) 106 r = latval(i->arg[1]); 107 else 108 r = CON_Z.val; 109 if (l == Bot || r == Bot) 110 v = Bot; 111 else if (l == Top || r == Top) 112 v = Top; 113 else 114 v = opfold(i->op, i->cls, &fn->con[l], &fn->con[r], fn); 115 } else 116 v = Bot; 117 /* fprintf(stderr, "\nvisiting %s (%p)", optab[i->op].name, (void *)i); */ 118 update(i->to.val, v, fn); 119 } 120 121 static void 122 visitjmp(Blk *b, int n, Fn *fn) 123 { 124 int l; 125 126 switch (b->jmp.type) { 127 case Jjnz: 128 l = latval(b->jmp.arg); 129 if (l == Bot) { 130 edge[n][1].work = flowrk; 131 edge[n][0].work = &edge[n][1]; 132 flowrk = &edge[n][0]; 133 } 134 else if (iscon(&fn->con[l], 0, 0)) { 135 assert(edge[n][0].dead); 136 edge[n][1].work = flowrk; 137 flowrk = &edge[n][1]; 138 } 139 else { 140 assert(edge[n][1].dead); 141 edge[n][0].work = flowrk; 142 flowrk = &edge[n][0]; 143 } 144 break; 145 case Jjmp: 146 edge[n][0].work = flowrk; 147 flowrk = &edge[n][0]; 148 break; 149 case Jhlt: 150 break; 151 default: 152 if (isret(b->jmp.type)) 153 break; 154 die("unreachable"); 155 } 156 } 157 158 static void 159 initedge(Edge *e, Blk *s) 160 { 161 if (s) 162 e->dest = s->id; 163 else 164 e->dest = -1; 165 e->dead = 1; 166 e->work = 0; 167 } 168 169 static int 170 renref(Ref *r) 171 { 172 int l; 173 174 if (rtype(*r) == RTmp) 175 if ((l=val[r->val]) != Bot) { 176 *r = CON(l); 177 return 1; 178 } 179 return 0; 180 } 181 182 /* require rpo, use, pred */ 183 void 184 fold(Fn *fn) 185 { 186 Edge *e, start; 187 Use *u; 188 Blk *b, **pb; 189 Phi *p, **pp; 190 Ins *i; 191 int t, d; 192 uint n, a; 193 194 val = emalloc(fn->ntmp * sizeof val[0]); 195 edge = emalloc(fn->nblk * sizeof edge[0]); 196 usewrk = vnew(0, sizeof usewrk[0], PHeap); 197 198 for (t=0; t<fn->ntmp; t++) 199 val[t] = Top; 200 for (n=0; n<fn->nblk; n++) { 201 b = fn->rpo[n]; 202 b->visit = 0; 203 initedge(&edge[n][0], b->s1); 204 initedge(&edge[n][1], b->s2); 205 } 206 initedge(&start, fn->start); 207 flowrk = &start; 208 nuse = 0; 209 210 /* 1. find out constants and dead cfg edges */ 211 for (;;) { 212 e = flowrk; 213 if (e) { 214 flowrk = e->work; 215 e->work = 0; 216 if (e->dest == -1 || !e->dead) 217 continue; 218 e->dead = 0; 219 n = e->dest; 220 b = fn->rpo[n]; 221 for (p=b->phi; p; p=p->link) 222 visitphi(p, n, fn); 223 if (b->visit == 0) { 224 for (i=b->ins; i<&b->ins[b->nins]; i++) 225 visitins(i, fn); 226 visitjmp(b, n, fn); 227 } 228 b->visit++; 229 assert(b->jmp.type != Jjmp 230 || !edge[n][0].dead 231 || flowrk == &edge[n][0]); 232 } 233 else if (nuse) { 234 u = usewrk[--nuse]; 235 n = u->bid; 236 b = fn->rpo[n]; 237 if (b->visit == 0) 238 continue; 239 switch (u->type) { 240 case UPhi: 241 visitphi(u->u.phi, u->bid, fn); 242 break; 243 case UIns: 244 visitins(u->u.ins, fn); 245 break; 246 case UJmp: 247 visitjmp(b, n, fn); 248 break; 249 default: 250 die("unreachable"); 251 } 252 } 253 else 254 break; 255 } 256 257 if (debug['F']) { 258 fprintf(stderr, "\n> SCCP findings:"); 259 for (t=Tmp0; t<fn->ntmp; t++) { 260 if (val[t] == Bot) 261 continue; 262 fprintf(stderr, "\n%10s: ", fn->tmp[t].name); 263 if (val[t] == Top) 264 fprintf(stderr, "Top"); 265 else 266 printref(CON(val[t]), fn, stderr); 267 } 268 fprintf(stderr, "\n dead code: "); 269 } 270 271 /* 2. trim dead code, replace constants */ 272 d = 0; 273 for (pb=&fn->start; (b=*pb);) { 274 if (b->visit == 0) { 275 d = 1; 276 if (debug['F']) 277 fprintf(stderr, "%s ", b->name); 278 edgedel(b, &b->s1); 279 edgedel(b, &b->s2); 280 *pb = b->link; 281 continue; 282 } 283 for (pp=&b->phi; (p=*pp);) 284 if (val[p->to.val] != Bot) 285 *pp = p->link; 286 else { 287 for (a=0; a<p->narg; a++) 288 if (!deadedge(p->blk[a]->id, b->id)) 289 renref(&p->arg[a]); 290 pp = &p->link; 291 } 292 for (i=b->ins; i<&b->ins[b->nins]; i++) 293 if (renref(&i->to)) 294 *i = (Ins){.op = Onop}; 295 else { 296 for (n=0; n<2; n++) 297 renref(&i->arg[n]); 298 if (isstore(i->op)) 299 if (req(i->arg[0], UNDEF)) 300 *i = (Ins){.op = Onop}; 301 } 302 renref(&b->jmp.arg); 303 if (b->jmp.type == Jjnz && rtype(b->jmp.arg) == RCon) { 304 if (iscon(&fn->con[b->jmp.arg.val], 0, 0)) { 305 edgedel(b, &b->s1); 306 b->s1 = b->s2; 307 b->s2 = 0; 308 } else 309 edgedel(b, &b->s2); 310 b->jmp.type = Jjmp; 311 b->jmp.arg = R; 312 } 313 pb = &b->link; 314 } 315 316 if (debug['F']) { 317 if (!d) 318 fprintf(stderr, "(none)"); 319 fprintf(stderr, "\n\n> After constant folding:\n"); 320 printfn(fn, stderr); 321 } 322 323 free(val); 324 free(edge); 325 vfree(usewrk); 326 } 327 328 /* boring folding code */ 329 330 static int 331 foldint(Con *res, int op, int w, Con *cl, Con *cr) 332 { 333 union { 334 int64_t s; 335 uint64_t u; 336 float fs; 337 double fd; 338 } l, r; 339 uint64_t x; 340 Sym sym; 341 int typ; 342 343 memset(&sym, 0, sizeof sym); 344 typ = CBits; 345 l.s = cl->bits.i; 346 r.s = cr->bits.i; 347 if (op == Oadd) { 348 if (cl->type == CAddr) { 349 if (cr->type == CAddr) 350 return 1; 351 typ = CAddr; 352 sym = cl->sym; 353 } 354 else if (cr->type == CAddr) { 355 typ = CAddr; 356 sym = cr->sym; 357 } 358 } 359 else if (op == Osub) { 360 if (cl->type == CAddr) { 361 if (cr->type != CAddr) { 362 typ = CAddr; 363 sym = cl->sym; 364 } else if (!symeq(cl->sym, cr->sym)) 365 return 1; 366 } 367 else if (cr->type == CAddr) 368 return 1; 369 } 370 else if (cl->type == CAddr || cr->type == CAddr) 371 return 1; 372 if (op == Odiv || op == Orem || op == Oudiv || op == Ourem) { 373 if (iscon(cr, w, 0)) 374 return 1; 375 if (op == Odiv || op == Orem) { 376 x = w ? INT64_MIN : INT32_MIN; 377 if (iscon(cr, w, -1)) 378 if (iscon(cl, w, x)) 379 return 1; 380 } 381 } 382 switch (op) { 383 case Oadd: x = l.u + r.u; break; 384 case Osub: x = l.u - r.u; break; 385 case Oneg: x = -l.u; break; 386 case Odiv: x = w ? l.s / r.s : (int32_t)l.s / (int32_t)r.s; break; 387 case Orem: x = w ? l.s % r.s : (int32_t)l.s % (int32_t)r.s; break; 388 case Oudiv: x = w ? l.u / r.u : (uint32_t)l.u / (uint32_t)r.u; break; 389 case Ourem: x = w ? l.u % r.u : (uint32_t)l.u % (uint32_t)r.u; break; 390 case Omul: x = l.u * r.u; break; 391 case Oand: x = l.u & r.u; break; 392 case Oor: x = l.u | r.u; break; 393 case Oxor: x = l.u ^ r.u; break; 394 case Osar: x = (w ? l.s : (int32_t)l.s) >> (r.u & (31|w<<5)); break; 395 case Oshr: x = (w ? l.u : (uint32_t)l.u) >> (r.u & (31|w<<5)); break; 396 case Oshl: x = l.u << (r.u & (31|w<<5)); break; 397 case Oextsb: x = (int8_t)l.u; break; 398 case Oextub: x = (uint8_t)l.u; break; 399 case Oextsh: x = (int16_t)l.u; break; 400 case Oextuh: x = (uint16_t)l.u; break; 401 case Oextsw: x = (int32_t)l.u; break; 402 case Oextuw: x = (uint32_t)l.u; break; 403 case Ostosi: x = w ? (int64_t)cl->bits.s : (int32_t)cl->bits.s; break; 404 case Ostoui: x = w ? (uint64_t)cl->bits.s : (uint32_t)cl->bits.s; break; 405 case Odtosi: x = w ? (int64_t)cl->bits.d : (int32_t)cl->bits.d; break; 406 case Odtoui: x = w ? (uint64_t)cl->bits.d : (uint32_t)cl->bits.d; break; 407 case Ocast: 408 x = l.u; 409 if (cl->type == CAddr) { 410 typ = CAddr; 411 sym = cl->sym; 412 } 413 break; 414 default: 415 if (Ocmpw <= op && op <= Ocmpl1) { 416 if (op <= Ocmpw1) { 417 l.u = (int32_t)l.u; 418 r.u = (int32_t)r.u; 419 } else 420 op -= Ocmpl - Ocmpw; 421 switch (op - Ocmpw) { 422 case Ciule: x = l.u <= r.u; break; 423 case Ciult: x = l.u < r.u; break; 424 case Cisle: x = l.s <= r.s; break; 425 case Cislt: x = l.s < r.s; break; 426 case Cisgt: x = l.s > r.s; break; 427 case Cisge: x = l.s >= r.s; break; 428 case Ciugt: x = l.u > r.u; break; 429 case Ciuge: x = l.u >= r.u; break; 430 case Cieq: x = l.u == r.u; break; 431 case Cine: x = l.u != r.u; break; 432 default: die("unreachable"); 433 } 434 } 435 else if (Ocmps <= op && op <= Ocmps1) { 436 switch (op - Ocmps) { 437 case Cfle: x = l.fs <= r.fs; break; 438 case Cflt: x = l.fs < r.fs; break; 439 case Cfgt: x = l.fs > r.fs; break; 440 case Cfge: x = l.fs >= r.fs; break; 441 case Cfne: x = l.fs != r.fs; break; 442 case Cfeq: x = l.fs == r.fs; break; 443 case Cfo: x = l.fs < r.fs || l.fs >= r.fs; break; 444 case Cfuo: x = !(l.fs < r.fs || l.fs >= r.fs); break; 445 default: die("unreachable"); 446 } 447 } 448 else if (Ocmpd <= op && op <= Ocmpd1) { 449 switch (op - Ocmpd) { 450 case Cfle: x = l.fd <= r.fd; break; 451 case Cflt: x = l.fd < r.fd; break; 452 case Cfgt: x = l.fd > r.fd; break; 453 case Cfge: x = l.fd >= r.fd; break; 454 case Cfne: x = l.fd != r.fd; break; 455 case Cfeq: x = l.fd == r.fd; break; 456 case Cfo: x = l.fd < r.fd || l.fd >= r.fd; break; 457 case Cfuo: x = !(l.fd < r.fd || l.fd >= r.fd); break; 458 default: die("unreachable"); 459 } 460 } 461 else 462 die("unreachable"); 463 } 464 *res = (Con){.type=typ, .sym=sym, .bits={.i=x}}; 465 return 0; 466 } 467 468 static void 469 foldflt(Con *res, int op, int w, Con *cl, Con *cr) 470 { 471 float xs, ls, rs; 472 double xd, ld, rd; 473 474 if (cl->type != CBits || cr->type != CBits) 475 err("invalid address operand for '%s'", optab[op].name); 476 *res = (Con){.type = CBits}; 477 memset(&res->bits, 0, sizeof(res->bits)); 478 if (w) { 479 ld = cl->bits.d; 480 rd = cr->bits.d; 481 switch (op) { 482 case Oadd: xd = ld + rd; break; 483 case Osub: xd = ld - rd; break; 484 case Oneg: xd = -ld; break; 485 case Odiv: xd = ld / rd; break; 486 case Omul: xd = ld * rd; break; 487 case Oswtof: xd = (int32_t)cl->bits.i; break; 488 case Ouwtof: xd = (uint32_t)cl->bits.i; break; 489 case Osltof: xd = (int64_t)cl->bits.i; break; 490 case Oultof: xd = (uint64_t)cl->bits.i; break; 491 case Oexts: xd = cl->bits.s; break; 492 case Ocast: xd = ld; break; 493 default: die("unreachable"); 494 } 495 res->bits.d = xd; 496 res->flt = 2; 497 } else { 498 ls = cl->bits.s; 499 rs = cr->bits.s; 500 switch (op) { 501 case Oadd: xs = ls + rs; break; 502 case Osub: xs = ls - rs; break; 503 case Oneg: xs = -ls; break; 504 case Odiv: xs = ls / rs; break; 505 case Omul: xs = ls * rs; break; 506 case Oswtof: xs = (int32_t)cl->bits.i; break; 507 case Ouwtof: xs = (uint32_t)cl->bits.i; break; 508 case Osltof: xs = (int64_t)cl->bits.i; break; 509 case Oultof: xs = (uint64_t)cl->bits.i; break; 510 case Otruncd: xs = cl->bits.d; break; 511 case Ocast: xs = ls; break; 512 default: die("unreachable"); 513 } 514 res->bits.s = xs; 515 res->flt = 1; 516 } 517 } 518 519 static int 520 opfold(int op, int cls, Con *cl, Con *cr, Fn *fn) 521 { 522 Ref r; 523 Con c; 524 525 if (cls == Kw || cls == Kl) { 526 if (foldint(&c, op, cls == Kl, cl, cr)) 527 return Bot; 528 } else 529 foldflt(&c, op, cls == Kd, cl, cr); 530 if (!KWIDE(cls)) 531 c.bits.i &= 0xffffffff; 532 r = newcon(&c, fn); 533 assert(!(cls == Ks || cls == Kd) || c.flt); 534 return r.val; 535 }