1
0
mirror of https://github.com/janet-lang/janet synced 2025-04-08 00:06:38 +00:00

Add janet_sysir_scalarize

Makes it easier to add simpler backends without needing to completely
handle vectorization.
This commit is contained in:
Calvin Rose 2025-02-24 19:12:17 -06:00
parent 768c9b23e1
commit 9a1cd6fdd9
2 changed files with 239 additions and 33 deletions

View File

@ -17,3 +17,7 @@
(def ctx (sysir/context))
(sysir/asm ctx ir-asm)
(print (sysir/to-c ctx))
(printf "%.99M" (sysir/to-ir ctx))
(print (sysir/scalarize ctx))
(printf "%.99M" (sysir/to-ir ctx))
(print (sysir/to-c ctx))

View File

@ -103,9 +103,9 @@ const char *janet_sysop_names[] = {
"type-struct", /* JANET_SYSOP_TYPE_STRUCT */
"type-bind", /* JANET_SYSOP_TYPE_BIND */
"arg", /* JANET_SYSOP_ARG */
"field-getp", /* JANET_SYSOP_FIELD_GETP */
"array-getp", /* JANET_SYSOP_ARRAY_GETP */
"array-pgetp", /* JANET_SYSOP_ARRAY_PGETP */
"fgetp", /* JANET_SYSOP_FIELD_GETP */
"agetp", /* JANET_SYSOP_ARRAY_GETP */
"apgetp", /* JANET_SYSOP_ARRAY_PGETP */
"type-pointer", /* JANET_SYSOP_TYPE_POINTER */
"type-array", /* JANET_SYSOP_TYPE_ARRAY */
"type-union", /* JANET_SYSOP_TYPE_UNION */
@ -198,9 +198,9 @@ static JanetString *table_to_string_array(JanetTable *strings_to_indices, int32_
return NULL;
}
janet_assert(count > 0, "bad count");
JanetString *strings = janet_malloc(count * sizeof(JanetString));
JanetString *strings = NULL;
for (int32_t i = 0; i < count; i++) {
strings[i] = NULL;
janet_v_push(strings, NULL);
}
for (int32_t i = 0; i < strings_to_indices->capacity; i++) {
JanetKV *kv = strings_to_indices->data + i;
@ -307,24 +307,29 @@ static uint32_t instr_read_type_operand(Janet x, JanetSysIR *ir, ReadOpMode rmod
return operand;
}
static uint32_t janet_sys_makeconst(JanetSysIR *sysir, uint32_t type, Janet x) {
JanetSysConstant jsc;
jsc.type = type;
jsc.value = x;
for (int32_t i = 0; i < janet_v_count(sysir->constants); i++) {
if (sysir->constants[i].type != jsc.type) continue;
if (!janet_equals(sysir->constants[i].value, x)) continue;
/* Found a constant */
return JANET_SYS_CONSTANT_PREFIX + i;
}
uint32_t index = (uint32_t) janet_v_count(sysir->constants);
janet_v_push(sysir->constants, jsc);
sysir->constant_count++;
return JANET_SYS_CONSTANT_PREFIX + index;
}
static uint32_t instr_read_operand_or_const(Janet x, JanetSysIR *ir) {
if (janet_checktype(x, JANET_TUPLE)) {
JanetSysConstant jsc;
const Janet *tup = janet_unwrap_tuple(x);
if (janet_tuple_length(tup) != 2) janet_panicf("expected constant wrapped in tuple, got %p", x);
Janet c = tup[1];
jsc.type = instr_read_type_operand(tup[0], ir, READ_TYPE_REFERENCE);
jsc.value = c;
/* TODO - Use a hash table or something better than linear lookup */
for (int32_t i = 0; i < janet_v_count(ir->constants); i++) {
if (ir->constants[i].type != jsc.type) continue;
if (!janet_equals(ir->constants[i].value, c)) continue;
/* Found a constant */
return JANET_SYS_CONSTANT_PREFIX + i;
}
uint32_t index = (uint32_t) janet_v_count(ir->constants);
janet_v_push(ir->constants, jsc);
return JANET_SYS_CONSTANT_PREFIX + index;
uint32_t t = instr_read_type_operand(tup[0], ir, READ_TYPE_REFERENCE);
return janet_sys_makeconst(ir, t, c);
}
return instr_read_operand(x, ir);
}
@ -665,7 +670,6 @@ static void janet_sysir_init_instructions(JanetSysIR *out, JanetView instruction
/* Build constants */
out->constant_count = janet_v_count(out->constants);
out->constants = janet_v_flatten(out->constants);
}
/* Get a type index given an operand */
@ -724,14 +728,19 @@ static void tcheck_redef(JanetSysIR *ir, uint32_t typeid) {
static void janet_sysir_init_types(JanetSysIR *ir) {
JanetSysIRLinkage *linkage = ir->linkage;
JanetSysTypeField *fields = NULL;
JanetSysTypeInfo *type_defs = janet_realloc(linkage->type_defs, sizeof(JanetSysTypeInfo) * (linkage->type_def_count));
uint32_t field_offset = linkage->field_def_count;
uint32_t *types = janet_malloc(sizeof(uint32_t) * ir->register_count);
linkage->type_defs = type_defs;
ir->types = types;
for (uint32_t i = 0; i < ir->register_count; i++) {
ir->types[i] = 0;
JanetSysTypeInfo td;
memset(&td, 0, sizeof(td));
for (uint32_t i = 0; i < linkage->type_def_count; i++) {
janet_v_push(linkage->type_defs, td);
}
JanetSysTypeInfo *type_defs = linkage->type_defs;
uint32_t field_offset = linkage->field_def_count;
uint32_t *types = NULL;
linkage->type_defs = type_defs;
for (uint32_t i = 0; i < ir->register_count; i++) {
janet_v_push(types, 0);
}
ir->types = types;
for (uint32_t i = linkage->old_type_def_count; i < linkage->type_def_count; i++) {
type_defs[i].prim = JANET_PRIM_UNKNOWN;
}
@ -795,7 +804,7 @@ static void janet_sysir_init_types(JanetSysIR *ir) {
if (janet_v_count(fields)) {
uint32_t new_field_count = field_offset + janet_v_count(fields);
linkage->field_defs = janet_realloc(linkage->field_defs, sizeof(JanetSysTypeField) * new_field_count);
memcpy(linkage->field_defs + field_offset, fields, janet_v_count(fields) * sizeof(JanetSysTypeField));
safe_memcpy(linkage->field_defs + field_offset, fields, janet_v_count(fields) * sizeof(JanetSysTypeField));
linkage->field_def_count = new_field_count;
janet_v_free(fields);
}
@ -1332,7 +1341,7 @@ static void janet_sys_ir_init(JanetSysIR *out, JanetView instructions, JanetSysI
/* Patch up name mapping arrays */
/* TODO - make more efficient, don't rebuild from scratch every time */
if (linkage->type_names) janet_free((void *) linkage->type_names);
if (linkage->type_names) janet_v_free((void *) linkage->type_names);
linkage->type_names = table_to_string_array(linkage->type_name_lookup, linkage->type_def_count);
ir.register_names = table_to_string_array(ir.register_name_lookup, ir.register_count);
@ -1346,6 +1355,189 @@ static void janet_sys_ir_init(JanetSysIR *out, JanetView instructions, JanetSysI
janet_array_push(linkage->ir_ordered, janet_wrap_abstract(out));
}
/*
* Passes
*/
static JanetSysInstruction makethree(JanetSysInstruction source, JanetSysOp opcode, uint32_t dest, uint32_t lhs, uint32_t rhs) {
source.opcode = opcode;
source.three.dest = dest;
source.three.lhs = lhs;
source.three.rhs = rhs;
return source;
}
static JanetSysInstruction maketwo(JanetSysInstruction source, JanetSysOp opcode, uint32_t dest, uint32_t src) {
source.opcode = opcode;
source.two.dest = dest;
source.two.src = src;
return source;
}
static JanetSysInstruction makejmp(JanetSysInstruction source, JanetSysOp opcode, uint32_t to) {
source.opcode = opcode;
source.jump.to = to;
return source;
}
static JanetSysInstruction makebranch(JanetSysInstruction source, JanetSysOp opcode, uint32_t cond, uint32_t labelid) {
source.opcode = opcode;
source.branch.cond = cond;
source.branch.to = labelid;
return source;
}
static JanetSysInstruction makelabel(JanetSysInstruction source, JanetSysOp opcode, uint32_t id) {
source.opcode = opcode;
source.label.id = id;
return source;
}
static JanetSysInstruction makebind(JanetSysInstruction source, JanetSysOp opcode, uint32_t reg, uint32_t type) {
source.opcode = opcode;
source.type_bind.dest = reg;
source.type_bind.type = type;
return source;
}
static uint32_t janet_sysir_getreg(JanetSysIR *sysir, uint32_t type) {
uint32_t ret = sysir->register_count++;
janet_v_push(sysir->types, type);
return ret;
}
/* Find primitive types in the current linkage to avoid creating tons
* of copies of duplicate types. */
static uint32_t janet_sysir_findprim(JanetSysIRLinkage *linkage, JanetPrim prim, const char *type_name) {
for (uint32_t i = 0; i < linkage->type_def_count; i++) {
if (linkage->type_defs[i].prim == prim) {
return i;
}
}
/* Add new type */
JanetSysTypeInfo td;
memset(&td, 0, sizeof(td));
td.prim = prim;
janet_v_push(linkage->type_defs, td);
janet_table_put(linkage->type_name_lookup,
janet_csymbolv(type_name),
janet_wrap_number(linkage->type_def_count));
janet_v_push(linkage->type_names, janet_csymbol(type_name));
return linkage->type_def_count++;
}
/* Get a type that is a pointer to another type */
static uint32_t janet_sysir_findpointer(JanetSysIRLinkage *linkage, uint32_t to, const char *type_name) {
for (uint32_t i = 0; i < linkage->type_def_count; i++) {
if (linkage->type_defs[i].prim == JANET_PRIM_POINTER) {
if (linkage->type_defs[i].pointer.type == to) {
return i;
}
}
}
/* Add new type */
JanetSysTypeInfo td;
memset(&td, 0, sizeof(td));
td.prim = JANET_PRIM_POINTER;
td.pointer.type = to;
janet_v_push(linkage->type_defs, td);
janet_table_put(linkage->type_name_lookup,
janet_csymbolv(type_name),
janet_wrap_number(linkage->type_def_count));
janet_v_push(linkage->type_names, janet_csymbol(type_name));
return linkage->type_def_count++;
}
/* Unwrap vectorized binops to scalars in one pass to make certain lowering easier. */
static void janet_sysir_scalarize(JanetSysIRLinkage *linkage) {
uint32_t index_type = janet_sysir_findprim(linkage, JANET_PRIM_U32, "U32Index");
uint32_t boolean_type = janet_sysir_findprim(linkage, JANET_PRIM_BOOLEAN, "Boolean");
for (int32_t j = 0; j < linkage->ir_ordered->count; j++) {
JanetSysIR *sysir = janet_unwrap_abstract(linkage->ir_ordered->data[j]);
for (uint32_t i = 0; i < sysir->instruction_count; i++) {
JanetSysInstruction instruction = sysir->instructions[i];
sysir->error_ctx = janet_cstringv(janet_sysop_names[instruction.opcode]);
switch (instruction.opcode) {
default:
break;
case JANET_SYSOP_ADD:
case JANET_SYSOP_SUBTRACT:
case JANET_SYSOP_MULTIPLY:
case JANET_SYSOP_DIVIDE:
case JANET_SYSOP_BAND:
case JANET_SYSOP_BOR:
case JANET_SYSOP_BXOR:
case JANET_SYSOP_GT:
case JANET_SYSOP_LT:
case JANET_SYSOP_EQ:
case JANET_SYSOP_NEQ:
case JANET_SYSOP_GTE:
case JANET_SYSOP_LTE:
case JANET_SYSOP_SHL:
case JANET_SYSOP_SHR:
;
{
uint32_t dest_type = janet_sys_optype(sysir, instruction.three.dest);
uint32_t test_type = dest_type;
if (linkage->type_defs[dest_type].prim == JANET_PRIM_POINTER) {
test_type = linkage->type_defs[dest_type].pointer.type;
}
if (linkage->type_defs[test_type].prim != JANET_PRIM_ARRAY) {
break;
}
uint32_t pel_type = janet_sysir_findpointer(linkage, linkage->type_defs[test_type].array.type, "PointerTo"); // fixme - type name would need to be unique
uint32_t lhs_type = janet_sys_optype(sysir, instruction.three.lhs);
uint32_t rhs_type = janet_sys_optype(sysir, instruction.three.rhs);
uint32_t array_size = linkage->type_defs[dest_type].array.fixed_count;
uint32_t index_reg = janet_sysir_getreg(sysir, index_type);
uint32_t compare_reg = janet_sysir_getreg(sysir, boolean_type);
uint32_t temp_lhs = janet_sysir_getreg(sysir, pel_type);
uint32_t temp_rhs = janet_sysir_getreg(sysir, pel_type);
uint32_t temp_dest = janet_sysir_getreg(sysir, pel_type);
uint32_t loopstart_label = sysir->label_count++;
uint32_t loopend_label = sysir->label_count++;
Janet labelkw_loopstart = janet_wrap_keyword(janet_symbol_gen());
Janet labelkw_loopend = janet_wrap_keyword(janet_symbol_gen());
JanetSysOp lhs_getp = (linkage->type_defs[lhs_type].prim == JANET_PRIM_POINTER) ? JANET_SYSOP_ARRAY_PGETP : JANET_SYSOP_ARRAY_GETP;
JanetSysOp rhs_getp = (linkage->type_defs[rhs_type].prim == JANET_PRIM_POINTER) ? JANET_SYSOP_ARRAY_PGETP : JANET_SYSOP_ARRAY_GETP;
JanetSysOp dest_getp = (linkage->type_defs[dest_type].prim == JANET_PRIM_POINTER) ? JANET_SYSOP_ARRAY_PGETP : JANET_SYSOP_ARRAY_GETP;
JanetSysInstruction patch[] = {
makebind(instruction, JANET_SYSOP_TYPE_BIND, index_reg, index_type),
makebind(instruction, JANET_SYSOP_TYPE_BIND, temp_lhs, pel_type),
makebind(instruction, JANET_SYSOP_TYPE_BIND, temp_rhs, pel_type),
makebind(instruction, JANET_SYSOP_TYPE_BIND, temp_dest, pel_type),
makebind(instruction, JANET_SYSOP_TYPE_BIND, compare_reg, boolean_type),
maketwo(instruction, JANET_SYSOP_LOAD, index_reg, janet_sys_makeconst(sysir, index_type, janet_wrap_number(0))),
makelabel(instruction, JANET_SYSOP_LABEL, loopstart_label),
makethree(instruction, JANET_SYSOP_GTE, compare_reg, index_reg, janet_sys_makeconst(sysir, index_type, janet_wrap_number(array_size))),
makebranch(instruction, JANET_SYSOP_BRANCH, compare_reg, loopend_label),
makethree(instruction, lhs_getp, temp_lhs, instruction.three.lhs, index_reg),
makethree(instruction, rhs_getp, temp_rhs, instruction.three.rhs, index_reg),
makethree(instruction, dest_getp, temp_dest, instruction.three.dest, index_reg),
makethree(instruction, instruction.opcode, temp_dest, temp_lhs, temp_rhs),
makethree(instruction, JANET_SYSOP_ADD, index_reg, index_reg, janet_sys_makeconst(sysir, index_type, janet_wrap_number(1))),
makejmp(instruction, JANET_SYSOP_JUMP, loopstart_label),
makelabel(instruction, JANET_SYSOP_LABEL, loopend_label)
};
size_t patchcount = sizeof(patch) / sizeof(patch[0]);
janet_table_put(sysir->labels, labelkw_loopstart, janet_wrap_number(loopstart_label));
janet_table_put(sysir->labels, labelkw_loopend, janet_wrap_number(loopend_label));
janet_table_put(sysir->labels, janet_wrap_number(loopstart_label), janet_wrap_number(i + 1));
janet_table_put(sysir->labels, janet_wrap_number(loopend_label), janet_wrap_number(i + patchcount - 1));
size_t remaining = (sysir->instruction_count - i - 1) * sizeof(JanetSysInstruction);
sysir->instructions = janet_realloc(sysir->instructions, (sysir->instruction_count + patchcount - 1) * sizeof(JanetSysInstruction));
if (remaining) memmove(sysir->instructions + i + patchcount, sysir->instructions + i + 1, remaining);
safe_memcpy(sysir->instructions + i, patch, sizeof(patch));
i += patchcount - 2;
sysir->instruction_count += patchcount - 1;
break;
}
}
}
}
}
/* Lowering to C */
static const char *c_prim_names[] = {
@ -1917,10 +2109,10 @@ void janet_sys_ir_lower_to_ir(JanetSysIRLinkage *linkage, JanetArray *into) {
static int sysir_gc(void *p, size_t s) {
JanetSysIR *ir = (JanetSysIR *)p;
(void) s;
janet_free(ir->constants);
janet_free(ir->types);
janet_v_free(ir->constants);
janet_v_free(ir->types);
janet_v_free(ir->register_names);
janet_free(ir->instructions);
janet_free((void *) ir->register_names);
return 0;
}
@ -1949,8 +2141,8 @@ static int sysir_context_gc(void *p, size_t s) {
JanetSysIRLinkage *linkage = (JanetSysIRLinkage *)p;
(void) s;
janet_free(linkage->field_defs);
janet_free(linkage->type_defs);
janet_free((void *) linkage->type_names);
janet_v_free(linkage->type_defs);
janet_v_free((void *) linkage->type_names);
return 0;
}
@ -2024,6 +2216,15 @@ JANET_CORE_FN(cfun_sysir_toir,
return janet_wrap_array(array);
}
JANET_CORE_FN(cfun_sysir_scalarize,
"(sysir/scalarize context)",
"Lower all vectorized instrinsics to loops of scalar operations.") {
janet_fixarity(argc, 1);
JanetSysIRLinkage *ir = janet_getabstract(argv, 0, &janet_sysir_context_type);
janet_sysir_scalarize(ir);
return argv[0];
}
JANET_CORE_FN(cfun_sysir_tox64,
"(sysir/to-x64 context &opt buffer target)",
"Lower IR to x64 machine code.") {
@ -2052,6 +2253,7 @@ void janet_lib_sysir(JanetTable *env) {
JanetRegExt cfuns[] = {
JANET_CORE_REG("sysir/context", cfun_sysir_context),
JANET_CORE_REG("sysir/asm", cfun_sysir_asm),
JANET_CORE_REG("sysir/scalarize", cfun_sysir_scalarize),
JANET_CORE_REG("sysir/to-c", cfun_sysir_toc),
JANET_CORE_REG("sysir/to-ir", cfun_sysir_toir),
JANET_CORE_REG("sysir/to-x64", cfun_sysir_tox64),