1
0
mirror of https://github.com/janet-lang/janet synced 2025-08-04 21:13:51 +00:00

Work on some local type inference.

Right to left type inference in expressions for binary operators.
This commit is contained in:
Calvin Rose 2024-09-29 11:37:04 -05:00
parent a588f1f242
commit e96dd512f3
2 changed files with 212 additions and 166 deletions

View File

@ -12,13 +12,14 @@
# * tail call returns # * tail call returns
# * function definitions # * function definitions
# * arrays (declaration, loads, stores) # * arrays (declaration, loads, stores)
# * ...
# insight - using : inside symbols for types can be used to allow manipulating symbols with macros (defdyn *ret-type* "Current function return type")
(def slot-to-name @[]) (def slot-to-name @[])
(def name-to-slot @{}) (def name-to-slot @{})
(def type-to-name @[]) (def type-to-name @[])
(def name-to-type @{}) (def name-to-type @{})
(def slot-types @{})
(defn get-slot (defn get-slot
[&opt new-name] [&opt new-name]
@ -44,6 +45,24 @@
(assert t) (assert t)
t) t)
(defn binding-type
  ``Look up the type recorded for the named binding `name`.
  The name is first resolved to its slot through `name-to-slot`, then the
  slot's entry in `slot-types` is returned. Raises an assertion error if
  the name has no slot or the slot has no recorded type.``
  [name]
  (let [slot (assert (get name-to-slot name))]
    (assert (get slot-types slot))))
(defn slot-type
  ``Return the type recorded for `slot` in `slot-types`.
  Raises an assertion error when no type has been recorded for the slot.``
  [slot]
  (assert (get slot-types slot)))
(defn assign-type
  ``Record `typ` as the type of the named binding `name`.
  The name is resolved to its slot through `name-to-slot` and the type is
  stored in `slot-types` under that slot. Returns the result of `put`.
  Raises an error if `name` has no slot: without this check the lookup
  yields nil and `put` with a nil key stores nothing, so the missing type
  would only be noticed later when `binding-type` is called.``
  [name typ]
  # Fail here, at the point of the mistake, mirroring the check in binding-type.
  (def slot
    (assertf (get name-to-slot name)
             "cannot assign type %v to unknown binding %v" typ name))
  (put slot-types slot typ))
(defn assign-slot-type
  ``Record `typ` as the type of `slot` directly in `slot-types`,
  without going through a binding name. Returns the result of `put`.``
  [slot typ]
  (put slot-types slot typ))
(defn setup-default-types (defn setup-default-types
[ctx] [ctx]
(def into @[]) (def into @[])
@ -76,195 +95,215 @@
### Inside functions ### Inside functions
### ###
(defdyn *ret-type* "Return type hint if inside function body")
(defn visit1 (defn visit1
"Take in a form and compile code and put it into `into`. Return result slot." "Take in a form and compile code and put it into `into`. Return result slot."
[code into &opt no-return type-hint] [code into &opt no-return type-hint]
(cond (def subresult
(cond
# Compile a constant # Compile a constant
(string? code) ~(pointer ,code) (string? code) ~(pointer ,code)
(boolean? code) ~(boolean ,code) (boolean? code) ~(boolean ,code)
(number? code) ~(,(or type-hint 'long) ,code) # TODO - should default to double (number? code) ~(,(or type-hint 'double) ,code) # TODO - should default to double
# Needed? # Needed?
(= :core/u64 (type code)) ~(,(or type-hint 'ulong) ,code) (= :core/u64 (type code)) ~(ulong ,code)
(= :core/s64 (type code)) ~(,(or type-hint 'long) ,code) (= :core/s64 (type code)) ~(long ,code)
# Binding # Binding
(symbol? code) (symbol? code)
(named-slot code) (named-slot code)
# Compile forms # Compile forms
(and (tuple? code) (= :parens (tuple/type code))) (and (tuple? code) (= :parens (tuple/type code)))
(do (do
(assert (> (length code) 0)) (assert (> (length code) 0))
(def [op & args] code) (def [op & args] code)
(case op (case op
# Arithmetic # Arithmetic
'+ (do-binop 'add args into type-hint) '+ (do-binop 'add args into type-hint)
'- (do-binop 'subtract args into type-hint) '- (do-binop 'subtract args into type-hint)
'* (do-binop 'multiply args into type-hint) '* (do-binop 'multiply args into type-hint)
'/ (do-binop 'divide args into type-hint) '/ (do-binop 'divide args into type-hint)
'<< (do-binop 'shl args into type-hint) '<< (do-binop 'shl args into type-hint)
'>> (do-binop 'shr args into type-hint) '>> (do-binop 'shr args into type-hint)
# Comparison # Comparison
'= (do-comp 'eq args into) '= (do-comp 'eq args into)
'not= (do-comp 'neq args into) 'not= (do-comp 'neq args into)
'< (do-comp 'lt args into) '< (do-comp 'lt args into)
'<= (do-comp 'lte args into) '<= (do-comp 'lte args into)
'> (do-comp 'gt args into) '> (do-comp 'gt args into)
'>= (do-comp 'gte args into) '>= (do-comp 'gte args into)
# Type hinting # Type hinting
'the 'the
(do (do
(assert (= 2 (length args))) (assert (= 2 (length args)))
(def [xtype x] args) (def [xtype x] args)
(def result (visit1 x into false xtype)) (def result (visit1 x into false xtype))
(if (tuple? result) # constant (if (tuple? result) # constant
(let [[t y] result] (let [[t y] result]
(assertf (= t xtype) "type mismatch, %p doesn't match %p" t xtype) (assertf (= t xtype) "type mismatch, %p doesn't match %p" t xtype)
[xtype y]) [xtype y])
(do (do
(array/push into ~(bind ,result ,xtype)) (array/push into ~(bind ,result ,xtype))
result))) result)))
# Named bindings # Named bindings
'def 'def
(do (do
(assert (= 2 (length args))) (assert (= 2 (length args)))
(def [full-name value] args) (def [full-name value] args)
(assert (symbol? full-name)) (assert (symbol? full-name))
(def [name tp] (type-extract full-name 'int)) (def [name tp] (type-extract full-name 'int))
(def result (visit1 value into false tp)) (def result (visit1 value into false tp))
(def slot (get-slot name)) (def slot (get-slot name))
(when tp (assign-type name tp)
(array/push into ~(bind ,slot ,tp))) (array/push into ~(bind ,slot ,tp))
(array/push into ~(move ,slot ,result)) (array/push into ~(move ,slot ,result))
slot) slot)
# Named variables # Named variables
'var 'var
(do (do
(assert (= 2 (length args))) (assert (= 2 (length args)))
(def [full-name value] args) (def [full-name value] args)
(assert (symbol? full-name)) (assert (symbol? full-name))
(def [name tp] (type-extract full-name 'int)) (def [name tp] (type-extract full-name 'int))
(def result (visit1 value into false tp)) (def result (visit1 value into false tp))
(def slot (get-slot name)) (def slot (get-slot name))
(when tp (assign-type name tp)
(array/push into ~(bind ,slot ,tp))) (array/push into ~(bind ,slot ,tp))
(array/push into ~(move ,slot ,result)) (array/push into ~(move ,slot ,result))
slot) slot)
# Assignment # Assignment
'set 'set
(do (do
(assert (= 2 (length args))) (assert (= 2 (length args)))
(def [to x] args) (def [to x] args)
(def result (visit1 x into false)) (def type-hint (binding-type to))
(def toslot (named-slot to)) (def result (visit1 x into false type-hint))
(array/push into ~(move ,toslot ,result)) (def toslot (named-slot to))
toslot) (array/push into ~(move ,toslot ,result))
toslot)
# Return # Return
'return 'return
(do (do
(assert (>= 1 (length args))) (assert (>= 1 (length args)))
(if (empty? args) (if (empty? args)
(array/push into '(return)) (array/push into '(return))
(do (do
(def [x] args) (def [x] args)
(array/push into ~(return ,(visit1 x into false (dyn *ret-type*)))))) (array/push into ~(return ,(visit1 x into false (dyn *ret-type*))))))
nil) nil)
# Sequence of operations # Sequence of operations
'do 'do
(do (do
(each form (slice args 0 -2) (visit1 form into true)) (each form (slice args 0 -2) (visit1 form into true))
(visit1 (last args) into false type-hint)) (visit1 (last args) into false type-hint))
# While loop # While loop
'while 'while
(do (do
(def lab-test (keyword (gensym))) (def lab-test (keyword (gensym)))
(def lab-exit (keyword (gensym))) (def lab-exit (keyword (gensym)))
(assert (< 1 (length args))) (assert (< 1 (length args)))
(def [cnd & body] args) (def [cnd & body] args)
(array/push into lab-test) (array/push into lab-test)
(def condition-slot (visit1 cnd into false 'boolean)) (def condition-slot (visit1 cnd into false 'boolean))
(array/push into ~(branch-not ,condition-slot ,lab-exit)) (array/push into ~(branch-not ,condition-slot ,lab-exit))
(each code body (each code body
(visit1 code into true)) (visit1 code into true))
(array/push into ~(jump ,lab-test)) (array/push into ~(jump ,lab-test))
(array/push into lab-exit) (array/push into lab-exit)
nil) nil)
# Branch # Branch
'if 'if
(do (do
(def lab (keyword (gensym))) (def lab (keyword (gensym)))
(def lab-end (keyword (gensym))) (def lab-end (keyword (gensym)))
(assert (< 2 (length args) 4)) (assert (< 2 (length args) 4))
(def [cnd tru fal] args) (def [cnd tru fal] args)
(def condition-slot (visit1 cnd into false 'boolean)) (def condition-slot (visit1 cnd into false 'boolean))
(def ret (get-slot)) (def ret (get-slot))
(array/push into ~(bind ,ret ,type-hint)) (array/push into ~(bind ,ret ,type-hint))
(array/push into ~(branch ,condition-slot ,lab)) (array/push into ~(branch ,condition-slot ,lab))
# false path # false path
(array/push into ~(move ,ret ,(visit1 tru into false type-hint))) (array/push into ~(move ,ret ,(visit1 tru into false type-hint)))
(array/push into ~(jump ,lab-end)) (array/push into ~(jump ,lab-end))
(array/push into lab) (array/push into lab)
# true path # true path
(array/push into ~(move ,ret ,(visit1 fal into false type-hint))) (array/push into ~(move ,ret ,(visit1 fal into false type-hint)))
(array/push into lab-end) (array/push into lab-end)
ret) ret)
# Insert IR # Insert IR
'ir 'ir
(do (array/push into ;args) nil) (do
(assert no-return)
(array/push into ;args)
nil)
# Syscall # Syscall
'syscall 'syscall
(do (do
(def slots @[]) (def slots @[])
(def ret (if no-return nil (get-slot))) (def ret (if no-return nil (get-slot)))
(each arg args (each arg args
(array/push slots (visit1 arg into))) (array/push slots (visit1 arg into)))
(array/push into ~(syscall :default ,ret ,;slots)) (array/push into ~(syscall :default ,ret ,;slots))
ret) ret)
# Assume function call # Assume function call
(do (do
(def slots @[]) (def slots @[])
(def ret (if no-return nil (get-slot))) (def ret (if no-return nil (get-slot)))
(each arg args (each arg args
(array/push slots (visit1 arg into))) (array/push slots (visit1 arg into)))
(array/push into ~(call :default ,ret [pointer ,op] ,;slots)) (array/push into ~(call :default ,ret [pointer ,op] ,;slots))
ret))) ret)))
(errorf "cannot compile %q" code))) (errorf "cannot compile %q" code)))
# Check type-hint matches return type
(if type-hint
(when-let [t (first subresult)] # TODO - Disallow empty types
(assert (= type-hint t) (string/format "%j, expected type %v, got %v" code type-hint t))))
subresult)
(varfn do-binop (varfn do-binop
"Emit an operation such as (+ x y). "Emit an operation such as (+ x y).
Extended to support any number of arguments such as (+ x y z ...)" Extended to support any number of arguments such as (+ x y z ...)"
[opcode args into type-hint] [opcode args into type-hint]
(var typ type-hint)
(var final nil) (var final nil)
(def slots @[])
(each arg args (each arg args
(def right (visit1 arg into false type-hint)) (def right (visit1 arg into false typ))
(when (number? right) (array/push slots right))
# If we don't have a type hint, infer types from bottom up
(when (nil? typ)
(when-let [new-typ (get slot-types right)]
(set typ new-typ)))
(set final (set final
(if final (if final
(let [result (get-slot)] (let [result (get-slot)]
# TODO - finish type inference - we should be able to omit the bind (array/push slots result)
# call and sysir should be able to infer the type
(array/push into ~(bind ,result int)) # Why int?
(array/push into ~(,opcode ,result ,final ,right)) (array/push into ~(,opcode ,result ,final ,right))
result) result)
right))) right)))
(assert typ (string "unable to infer type for %j" [opcode ;args]))
(each slot (distinct slots)
(array/push into ~(bind ,slot ,typ)))
(assert final)) (assert final))
(varfn do-comp (varfn do-comp
@ -278,8 +317,13 @@
(array/push into ~(bind ,temp-result boolean))) (array/push into ~(bind ,temp-result boolean)))
(var left nil) (var left nil)
(var first-compare true) (var first-compare true)
(var typ nil)
(each arg args (each arg args
(def right (visit1 arg into)) (def right (visit1 arg into false typ))
# If we don't have a type hint, infer types from bottom up
(when (nil? typ)
(when-let [new-typ (get slot-types right)]
(set typ new-typ)))
(when left (when left
(if first-compare (if first-compare
(array/push into ~(,opcode ,result ,left ,right)) (array/push into ~(,opcode ,result ,left ,right))
@ -306,6 +350,7 @@
(do (do
# TODO doc strings # TODO doc strings
(table/clear name-to-slot) (table/clear name-to-slot)
(table/clear slot-types)
(array/clear slot-to-name) (array/clear slot-to-name)
(def [name args & body] rest) (def [name args & body] rest)
(assert (tuple? args)) (assert (tuple? args))
@ -317,11 +362,12 @@
(each arg args (each arg args
(def [name tp] (type-extract arg 'int)) (def [name tp] (type-extract arg 'int))
(def slot (get-slot name)) (def slot (get-slot name))
(assign-type name tp)
(array/push ir-asm ~(bind ,slot ,tp))) (array/push ir-asm ~(bind ,slot ,tp)))
(with-dyns [*ret-type* fn-tp] (with-dyns [*ret-type* fn-tp]
(each part body (each part body
(visit1 part ir-asm true))) (visit1 part ir-asm true)))
# (eprintf "%.99M" ir-asm) (eprintf "%.99M" ir-asm)
(sysir/asm ctx ir-asm)) (sysir/asm ctx ir-asm))
(errorf "unknown form %v" form))) (errorf "unknown form %v" form)))
@ -348,7 +394,7 @@
(defn dumpx64-windows (defn dumpx64-windows
[] []
(print (sysir/to-x64 ctx @"" :windows))) (print (sysir/to-x64 ctx @"" :windows)))
(defn dumpc (defn dumpc
[] []
(print (sysir/to-c ctx))) (print (sysir/to-c ctx)))

View File

@ -2,7 +2,7 @@
(def square (def square
'(defn square:int [num:int] '(defn square:int [num:int]
(return (* num num)))) (return (* 1 num num))))
(def simple (def simple
'(defn simple:int [x:int] '(defn simple:int [x:int]
@ -13,10 +13,10 @@
'(defn myprog:int [] '(defn myprog:int []
(def xyz:int (+ 1 2 3)) (def xyz:int (+ 1 2 3))
(def abc:int (* 4 5 6)) (def abc:int (* 4 5 6))
(def x:boolean (= (the int 5) xyz)) (def x:boolean (= xyz 5))
(var i:int 0) (var i:int 0)
(while (< i (the int 10)) (while (< i 10)
(set i (the int (+ 1 i))) (set i (+ 1 i))
(printf "i = %d\n" i)) (printf "i = %d\n" i))
(printf "hello, world!\n%d\n" (the int (if x abc xyz))) (printf "hello, world!\n%d\n" (the int (if x abc xyz)))
#(return (* abc xyz)))) #(return (* abc xyz))))
@ -46,5 +46,5 @@
(compile1 doloop) (compile1 doloop)
(compile1 main-fn) (compile1 main-fn)
#(dump) #(dump)
#(dumpc) (dumpc)
(dumpx64) #(dumpx64)