From 9a1cd6fdd9e0c423d7bc69010f3ce449110f6543 Mon Sep 17 00:00:00 2001 From: Calvin Rose Date: Mon, 24 Feb 2025 19:12:17 -0600 Subject: [PATCH] Add janet_sysir_scalarize Makes it easier to add simpler backends without needing to completely handle vectorization. --- examples/sysir/arrays2.janet | 4 + src/core/sysir.c | 268 ++++++++++++++++++++++++++++++----- 2 files changed, 239 insertions(+), 33 deletions(-) diff --git a/examples/sysir/arrays2.janet b/examples/sysir/arrays2.janet index 64b7729f..41f089e4 100644 --- a/examples/sysir/arrays2.janet +++ b/examples/sysir/arrays2.janet @@ -17,3 +17,7 @@ (def ctx (sysir/context)) (sysir/asm ctx ir-asm) (print (sysir/to-c ctx)) +(printf "%.99M" (sysir/to-ir ctx)) +(print (sysir/scalarize ctx)) +(printf "%.99M" (sysir/to-ir ctx)) +(print (sysir/to-c ctx)) diff --git a/src/core/sysir.c b/src/core/sysir.c index 54c1e0ff..a65f836f 100644 --- a/src/core/sysir.c +++ b/src/core/sysir.c @@ -103,9 +103,9 @@ const char *janet_sysop_names[] = { "type-struct", /* JANET_SYSOP_TYPE_STRUCT */ "type-bind", /* JANET_SYSOP_TYPE_BIND */ "arg", /* JANET_SYSOP_ARG */ - "field-getp", /* JANET_SYSOP_FIELD_GETP */ - "array-getp", /* JANET_SYSOP_ARRAY_GETP */ - "array-pgetp", /* JANET_SYSOP_ARRAY_PGETP */ + "fgetp", /* JANET_SYSOP_FIELD_GETP */ + "agetp", /* JANET_SYSOP_ARRAY_GETP */ + "apgetp", /* JANET_SYSOP_ARRAY_PGETP */ "type-pointer", /* JANET_SYSOP_TYPE_POINTER */ "type-array", /* JANET_SYSOP_TYPE_ARRAY */ "type-union", /* JANET_SYSOP_TYPE_UNION */ @@ -198,9 +198,9 @@ static JanetString *table_to_string_array(JanetTable *strings_to_indices, int32_ return NULL; } janet_assert(count > 0, "bad count"); - JanetString *strings = janet_malloc(count * sizeof(JanetString)); + JanetString *strings = NULL; for (int32_t i = 0; i < count; i++) { - strings[i] = NULL; + janet_v_push(strings, NULL); } for (int32_t i = 0; i < strings_to_indices->capacity; i++) { JanetKV *kv = strings_to_indices->data + i; @@ -307,24 +307,29 @@ static uint32_t instr_read_type_operand(Janet x, JanetSysIR *ir, ReadOpMode rmod return operand; } +static uint32_t janet_sys_makeconst(JanetSysIR *sysir, uint32_t type, Janet x) { + JanetSysConstant jsc; + jsc.type = type; + jsc.value = x; + for (int32_t i = 0; i < janet_v_count(sysir->constants); i++) { + if (sysir->constants[i].type != jsc.type) continue; + if (!janet_equals(sysir->constants[i].value, x)) continue; + /* Found a constant */ + return JANET_SYS_CONSTANT_PREFIX + i; + } + uint32_t index = (uint32_t) janet_v_count(sysir->constants); + janet_v_push(sysir->constants, jsc); + sysir->constant_count++; + return JANET_SYS_CONSTANT_PREFIX + index; +} + static uint32_t instr_read_operand_or_const(Janet x, JanetSysIR *ir) { if (janet_checktype(x, JANET_TUPLE)) { - JanetSysConstant jsc; const Janet *tup = janet_unwrap_tuple(x); if (janet_tuple_length(tup) != 2) janet_panicf("expected constant wrapped in tuple, got %p", x); Janet c = tup[1]; - jsc.type = instr_read_type_operand(tup[0], ir, READ_TYPE_REFERENCE); - jsc.value = c; - /* TODO - Use a hash table or something better than linear lookup */ - for (int32_t i = 0; i < janet_v_count(ir->constants); i++) { - if (ir->constants[i].type != jsc.type) continue; - if (!janet_equals(ir->constants[i].value, c)) continue; - /* Found a constant */ - return JANET_SYS_CONSTANT_PREFIX + i; - } - uint32_t index = (uint32_t) janet_v_count(ir->constants); - janet_v_push(ir->constants, jsc); - return JANET_SYS_CONSTANT_PREFIX + index; + uint32_t t = instr_read_type_operand(tup[0], ir, READ_TYPE_REFERENCE); + return janet_sys_makeconst(ir, t, c); } return instr_read_operand(x, ir); } @@ -665,7 +670,6 @@ static void janet_sysir_init_instructions(JanetSysIR *out, JanetView instruction /* Build constants */ out->constant_count = janet_v_count(out->constants); - out->constants = janet_v_flatten(out->constants); } /* Get a type index given an operand */ @@ -724,14 +728,19 @@ static void tcheck_redef(JanetSysIR *ir, uint32_t typeid) { static void janet_sysir_init_types(JanetSysIR *ir) { JanetSysIRLinkage *linkage = ir->linkage; JanetSysTypeField *fields = NULL; - JanetSysTypeInfo *type_defs = janet_realloc(linkage->type_defs, sizeof(JanetSysTypeInfo) * (linkage->type_def_count)); - uint32_t field_offset = linkage->field_def_count; - uint32_t *types = janet_malloc(sizeof(uint32_t) * ir->register_count); - linkage->type_defs = type_defs; - ir->types = types; - for (uint32_t i = 0; i < ir->register_count; i++) { - ir->types[i] = 0; + JanetSysTypeInfo td; + memset(&td, 0, sizeof(td)); + for (uint32_t i = 0; i < linkage->type_def_count; i++) { + janet_v_push(linkage->type_defs, td); } + JanetSysTypeInfo *type_defs = linkage->type_defs; + uint32_t field_offset = linkage->field_def_count; + uint32_t *types = NULL; + linkage->type_defs = type_defs; + for (uint32_t i = 0; i < ir->register_count; i++) { + janet_v_push(types, 0); + } + ir->types = types; for (uint32_t i = linkage->old_type_def_count; i < linkage->type_def_count; i++) { type_defs[i].prim = JANET_PRIM_UNKNOWN; } @@ -795,7 +804,7 @@ static void janet_sysir_init_types(JanetSysIR *ir) { if (janet_v_count(fields)) { uint32_t new_field_count = field_offset + janet_v_count(fields); linkage->field_defs = janet_realloc(linkage->field_defs, sizeof(JanetSysTypeField) * new_field_count); - memcpy(linkage->field_defs + field_offset, fields, janet_v_count(fields) * sizeof(JanetSysTypeField)); + safe_memcpy(linkage->field_defs + field_offset, fields, janet_v_count(fields) * sizeof(JanetSysTypeField)); linkage->field_def_count = new_field_count; janet_v_free(fields); } @@ -1332,7 +1341,7 @@ static void janet_sys_ir_init(JanetSysIR *out, JanetView instructions, JanetSysI /* Patch up name mapping arrays */ /* TODO - make more efficient, don't rebuild from scratch every time */ - if (linkage->type_names) janet_free((void *) linkage->type_names); + if (linkage->type_names) janet_v_free((void *) linkage->type_names); linkage->type_names = table_to_string_array(linkage->type_name_lookup, linkage->type_def_count); ir.register_names = table_to_string_array(ir.register_name_lookup, ir.register_count); @@ -1346,6 +1355,189 @@ static void janet_sys_ir_init(JanetSysIR *out, JanetView instructions, JanetSysI janet_array_push(linkage->ir_ordered, janet_wrap_abstract(out)); } +/* + * Passes + */ + +static JanetSysInstruction makethree(JanetSysInstruction source, JanetSysOp opcode, uint32_t dest, uint32_t lhs, uint32_t rhs) { + source.opcode = opcode; + source.three.dest = dest; + source.three.lhs = lhs; + source.three.rhs = rhs; + return source; +} + +static JanetSysInstruction maketwo(JanetSysInstruction source, JanetSysOp opcode, uint32_t dest, uint32_t src) { + source.opcode = opcode; + source.two.dest = dest; + source.two.src = src; + return source; +} + +static JanetSysInstruction makejmp(JanetSysInstruction source, JanetSysOp opcode, uint32_t to) { + source.opcode = opcode; + source.jump.to = to; + return source; +} + +static JanetSysInstruction makebranch(JanetSysInstruction source, JanetSysOp opcode, uint32_t cond, uint32_t labelid) { + source.opcode = opcode; + source.branch.cond = cond; + source.branch.to = labelid; + return source; +} + +static JanetSysInstruction makelabel(JanetSysInstruction source, JanetSysOp opcode, uint32_t id) { + source.opcode = opcode; + source.label.id = id; + return source; +} + +static JanetSysInstruction makebind(JanetSysInstruction source, JanetSysOp opcode, uint32_t reg, uint32_t type) { + source.opcode = opcode; + source.type_bind.dest = reg; + source.type_bind.type = type; + return source; +} + + +static uint32_t janet_sysir_getreg(JanetSysIR *sysir, uint32_t type) { + uint32_t ret = sysir->register_count++; + janet_v_push(sysir->types, type); + return ret; +} + +/* Find primitive types in the current linkage to avoid creating tons + * of copies of duplicate types. */ +static uint32_t janet_sysir_findprim(JanetSysIRLinkage *linkage, JanetPrim prim, const char *type_name) { + for (uint32_t i = 0; i < linkage->type_def_count; i++) { + if (linkage->type_defs[i].prim == prim) { + return i; + } + } + /* Add new type */ + JanetSysTypeInfo td; + memset(&td, 0, sizeof(td)); + td.prim = prim; + janet_v_push(linkage->type_defs, td); + janet_table_put(linkage->type_name_lookup, + janet_csymbolv(type_name), + janet_wrap_number(linkage->type_def_count)); + janet_v_push(linkage->type_names, janet_csymbol(type_name)); + return linkage->type_def_count++; +} + +/* Get a type that is a pointer to another type */ +static uint32_t janet_sysir_findpointer(JanetSysIRLinkage *linkage, uint32_t to, const char *type_name) { + for (uint32_t i = 0; i < linkage->type_def_count; i++) { + if (linkage->type_defs[i].prim == JANET_PRIM_POINTER) { + if (linkage->type_defs[i].pointer.type == to) { + return i; + } + } + } + /* Add new type */ + JanetSysTypeInfo td; + memset(&td, 0, sizeof(td)); + td.prim = JANET_PRIM_POINTER; + td.pointer.type = to; + janet_v_push(linkage->type_defs, td); + janet_table_put(linkage->type_name_lookup, + janet_csymbolv(type_name), + janet_wrap_number(linkage->type_def_count)); + janet_v_push(linkage->type_names, janet_csymbol(type_name)); + return linkage->type_def_count++; +} + +/* Unwrap vectorized binops to scalars in one pass to make certain lowering easier. */ +static void janet_sysir_scalarize(JanetSysIRLinkage *linkage) { + uint32_t index_type = janet_sysir_findprim(linkage, JANET_PRIM_U32, "U32Index"); + uint32_t boolean_type = janet_sysir_findprim(linkage, JANET_PRIM_BOOLEAN, "Boolean"); + for (int32_t j = 0; j < linkage->ir_ordered->count; j++) { + JanetSysIR *sysir = janet_unwrap_abstract(linkage->ir_ordered->data[j]); + for (uint32_t i = 0; i < sysir->instruction_count; i++) { + JanetSysInstruction instruction = sysir->instructions[i]; + sysir->error_ctx = janet_cstringv(janet_sysop_names[instruction.opcode]); + switch (instruction.opcode) { + default: + break; + case JANET_SYSOP_ADD: + case JANET_SYSOP_SUBTRACT: + case JANET_SYSOP_MULTIPLY: + case JANET_SYSOP_DIVIDE: + case JANET_SYSOP_BAND: + case JANET_SYSOP_BOR: + case JANET_SYSOP_BXOR: + case JANET_SYSOP_GT: + case JANET_SYSOP_LT: + case JANET_SYSOP_EQ: + case JANET_SYSOP_NEQ: + case JANET_SYSOP_GTE: + case JANET_SYSOP_LTE: + case JANET_SYSOP_SHL: + case JANET_SYSOP_SHR: + ; + { + uint32_t dest_type = janet_sys_optype(sysir, instruction.three.dest); + uint32_t test_type = dest_type; + if (linkage->type_defs[dest_type].prim == JANET_PRIM_POINTER) { + test_type = linkage->type_defs[dest_type].pointer.type; + } + if (linkage->type_defs[test_type].prim != JANET_PRIM_ARRAY) { + break; + } + uint32_t pel_type = janet_sysir_findpointer(linkage, linkage->type_defs[test_type].array.type, "PointerTo"); // fixme - type name would need to be unique + uint32_t lhs_type = janet_sys_optype(sysir, instruction.three.lhs); + uint32_t rhs_type = janet_sys_optype(sysir, instruction.three.rhs); + uint32_t array_size = linkage->type_defs[dest_type].array.fixed_count; + uint32_t index_reg = janet_sysir_getreg(sysir, index_type); + uint32_t compare_reg = janet_sysir_getreg(sysir, boolean_type); + uint32_t temp_lhs = janet_sysir_getreg(sysir, pel_type); + uint32_t temp_rhs = janet_sysir_getreg(sysir, pel_type); + uint32_t temp_dest = janet_sysir_getreg(sysir, pel_type); + uint32_t loopstart_label = sysir->label_count++; + uint32_t loopend_label = sysir->label_count++; + Janet labelkw_loopstart = janet_wrap_keyword(janet_symbol_gen()); + Janet labelkw_loopend = janet_wrap_keyword(janet_symbol_gen()); + JanetSysOp lhs_getp = (linkage->type_defs[lhs_type].prim == JANET_PRIM_POINTER) ? JANET_SYSOP_ARRAY_PGETP : JANET_SYSOP_ARRAY_GETP; + JanetSysOp rhs_getp = (linkage->type_defs[rhs_type].prim == JANET_PRIM_POINTER) ? JANET_SYSOP_ARRAY_PGETP : JANET_SYSOP_ARRAY_GETP; + JanetSysOp dest_getp = (linkage->type_defs[dest_type].prim == JANET_PRIM_POINTER) ? JANET_SYSOP_ARRAY_PGETP : JANET_SYSOP_ARRAY_GETP; + JanetSysInstruction patch[] = { + makebind(instruction, JANET_SYSOP_TYPE_BIND, index_reg, index_type), + makebind(instruction, JANET_SYSOP_TYPE_BIND, temp_lhs, pel_type), + makebind(instruction, JANET_SYSOP_TYPE_BIND, temp_rhs, pel_type), + makebind(instruction, JANET_SYSOP_TYPE_BIND, temp_dest, pel_type), + makebind(instruction, JANET_SYSOP_TYPE_BIND, compare_reg, boolean_type), + maketwo(instruction, JANET_SYSOP_LOAD, index_reg, janet_sys_makeconst(sysir, index_type, janet_wrap_number(0))), + makelabel(instruction, JANET_SYSOP_LABEL, loopstart_label), + makethree(instruction, JANET_SYSOP_GTE, compare_reg, index_reg, janet_sys_makeconst(sysir, index_type, janet_wrap_number(array_size))), + makebranch(instruction, JANET_SYSOP_BRANCH, compare_reg, loopend_label), + makethree(instruction, lhs_getp, temp_lhs, instruction.three.lhs, index_reg), + makethree(instruction, rhs_getp, temp_rhs, instruction.three.rhs, index_reg), + makethree(instruction, dest_getp, temp_dest, instruction.three.dest, index_reg), + makethree(instruction, instruction.opcode, temp_dest, temp_lhs, temp_rhs), + makethree(instruction, JANET_SYSOP_ADD, index_reg, index_reg, janet_sys_makeconst(sysir, index_type, janet_wrap_number(1))), + makejmp(instruction, JANET_SYSOP_JUMP, loopstart_label), + makelabel(instruction, JANET_SYSOP_LABEL, loopend_label) + }; + size_t patchcount = sizeof(patch) / sizeof(patch[0]); + janet_table_put(sysir->labels, labelkw_loopstart, janet_wrap_number(loopstart_label)); + janet_table_put(sysir->labels, labelkw_loopend, janet_wrap_number(loopend_label)); + janet_table_put(sysir->labels, janet_wrap_number(loopstart_label), janet_wrap_number(i + 1)); + janet_table_put(sysir->labels, janet_wrap_number(loopend_label), janet_wrap_number(i + patchcount - 1)); + size_t remaining = (sysir->instruction_count - i - 1) * sizeof(JanetSysInstruction); + sysir->instructions = janet_realloc(sysir->instructions, (sysir->instruction_count + patchcount - 1) * sizeof(JanetSysInstruction)); + if (remaining) memmove(sysir->instructions + i + patchcount, sysir->instructions + i + 1, remaining); + safe_memcpy(sysir->instructions + i, patch, sizeof(patch)); + i += patchcount - 2; + sysir->instruction_count += patchcount - 1; + break; + } + } + } + } +} + /* Lowering to C */ static const char *c_prim_names[] = { @@ -1917,10 +2109,10 @@ void janet_sys_ir_lower_to_ir(JanetSysIRLinkage *linkage, JanetArray *into) { static int sysir_gc(void *p, size_t s) { JanetSysIR *ir = (JanetSysIR *)p; (void) s; - janet_free(ir->constants); - janet_free(ir->types); + janet_v_free(ir->constants); + janet_v_free(ir->types); + janet_v_free(ir->register_names); janet_free(ir->instructions); - janet_free((void *) ir->register_names); return 0; } @@ -1949,8 +2141,8 @@ static int sysir_context_gc(void *p, size_t s) { JanetSysIRLinkage *linkage = (JanetSysIRLinkage *)p; (void) s; janet_free(linkage->field_defs); - janet_free(linkage->type_defs); - janet_free((void *) linkage->type_names); + janet_v_free(linkage->type_defs); + janet_v_free((void *) linkage->type_names); return 0; } @@ -2024,6 +2216,15 @@ JANET_CORE_FN(cfun_sysir_toir, return janet_wrap_array(array); } +JANET_CORE_FN(cfun_sysir_scalarize, + "(sysir/scalarize context)", + "Lower all vectorized instrinsics to loops of scalar operations.") { + janet_fixarity(argc, 1); + JanetSysIRLinkage *ir = janet_getabstract(argv, 0, &janet_sysir_context_type); + janet_sysir_scalarize(ir); + return argv[0]; +} + JANET_CORE_FN(cfun_sysir_tox64, "(sysir/to-x64 context &opt buffer target)", "Lower IR to x64 machine code.") { @@ -2052,6 +2253,7 @@ void janet_lib_sysir(JanetTable *env) { JanetRegExt cfuns[] = { JANET_CORE_REG("sysir/context", cfun_sysir_context), JANET_CORE_REG("sysir/asm", cfun_sysir_asm), + JANET_CORE_REG("sysir/scalarize", cfun_sysir_scalarize), JANET_CORE_REG("sysir/to-c", cfun_sysir_toc), JANET_CORE_REG("sysir/to-ir", cfun_sysir_toir), JANET_CORE_REG("sysir/to-x64", cfun_sysir_tox64),