mirror of
https://github.com/osmarks/osmarkscalculator.git
synced 2024-12-22 00:40:25 +00:00
initial version
This commit is contained in:
commit
b8eddc0837
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
/target
|
||||
osmarkscalculator.zip
|
||||
osmarkscalculator.tar
|
||||
src.zip
|
167
Cargo.lock
generated
Normal file
167
Cargo.lock
generated
Normal file
@ -0,0 +1,167 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "61604a8f862e1d5c3229fdd78f8b02c68dcf73a4c4b05fd636d12240aaa242c1"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-channel"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e54ea8bc3fb1ee042f5aace6e3c6e025d3874866da222930f70ce62aceba0bfa"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97242a70df9b89a65d0b6df3c4bf5b9ce03c5b7309019777fbde37e7537f8762"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
"lazy_static",
|
||||
"memoffset",
|
||||
"scopeguard",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cfcae03edb34f947e64acdb1c33ec169824e20657e9ecb61cef6c8c74dcb8120"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inlinable_string"
|
||||
version = "0.1.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3094308123a0e9fd59659ce45e22de9f53fc1d2ac6e1feb9fef988e4f76cad77"
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.10.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9a9d19fa1e79b6215ff29b9d6880b706147f16e9b1dbb1e4e5947b5b02bc5e3"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.114"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b0005d08a8f7b65fb8073cb697aa0b12b631ed251ce73d862ce50eeb52ce3b50"
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "osmarkscalculator"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"inlinable_string",
|
||||
"itertools",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"crossbeam-deque",
|
||||
"either",
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e"
|
||||
dependencies = [
|
||||
"crossbeam-channel",
|
||||
"crossbeam-deque",
|
||||
"crossbeam-utils",
|
||||
"lazy_static",
|
||||
"num_cpus",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
|
12
Cargo.toml
Normal file
12
Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "osmarkscalculator"
|
||||
version = "0.1.0"
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
inlinable_string = "0.1"
|
||||
rayon = "1.5"
|
||||
itertools = "0.10"
|
57
src/env.rs
Normal file
57
src/env.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use inlinable_string::InlinableString;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::value::Value;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RuleResult {
|
||||
Exp(Value),
|
||||
Intrinsic(usize)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Rule {
|
||||
pub condition: Value,
|
||||
pub result: RuleResult,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Operation {
|
||||
pub commutative: bool,
|
||||
pub associative: bool,
|
||||
}
|
||||
|
||||
pub type Bindings = HashMap<InlinableString, Value>;
|
||||
pub type Ops = HashMap<InlinableString, Operation>;
|
||||
pub type Ruleset = HashMap<InlinableString, Vec<Rule>>;
|
||||
pub type Intrinsics = HashMap<usize, Box<dyn Fn(&Bindings) -> Result<Value> + Sync + Send>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Env {
|
||||
pub ops: Arc<Ops>,
|
||||
pub ruleset: Vec<Arc<Ruleset>>,
|
||||
pub intrinsics: Arc<Intrinsics>,
|
||||
pub bindings: Bindings
|
||||
}
|
||||
|
||||
impl Env {
|
||||
// Get details about an operation, falling back to the default of not commutative/associative if none exists
|
||||
pub fn get_op(&self, name: &InlinableString) -> Operation {
|
||||
self.ops.get(name).map(Clone::clone).unwrap_or(Operation { commutative: false, associative: false })
|
||||
}
|
||||
|
||||
// Make a new Env with this extra ruleset in it
|
||||
pub fn with_ruleset(&self, ruleset: Arc<Ruleset>) -> Self {
|
||||
let mut new_env = self.clone();
|
||||
new_env.ruleset.push(ruleset);
|
||||
new_env
|
||||
}
|
||||
|
||||
pub fn with_bindings(&self, bindings: &Bindings) -> Self {
|
||||
let mut new_env = self.clone();
|
||||
new_env.bindings = bindings.clone();
|
||||
new_env
|
||||
}
|
||||
}
|
633
src/main.rs
Normal file
633
src/main.rs
Normal file
@ -0,0 +1,633 @@
|
||||
use anyhow::{Result, Context, bail};
|
||||
use inlinable_string::InlinableString;
|
||||
use std::collections::HashMap;
|
||||
use std::io::BufRead;
|
||||
use std::borrow::Cow;
|
||||
use std::convert::TryInto;
|
||||
use std::sync::Arc;
|
||||
use rayon::prelude::*;
|
||||
|
||||
mod parse;
|
||||
mod value;
|
||||
mod util;
|
||||
mod env;
|
||||
|
||||
use value::Value;
|
||||
use env::{Rule, Ruleset, Env, Bindings, RuleResult, Operation};
|
||||
|
||||
// Main pattern matcher function;
|
||||
fn match_and_bind(expr: &Value, rule: &Rule, env: &Env) -> Result<Option<Value>> {
|
||||
fn go(expr: &Value, cond: &Value, env: &Env, already_bound: &Bindings) -> Result<Option<Bindings>> {
|
||||
match (expr, cond) {
|
||||
// numbers match themselves
|
||||
(Value::Num(a), Value::Num(b)) => if a == b { Ok(Some(HashMap::new())) } else { Ok(None) },
|
||||
// handle predicated value - check all predicates, succeed with binding if they match
|
||||
(val, Value::Call(x, args)) if x == "#" => {
|
||||
let preds = &args[1..];
|
||||
let (mut success, mut bindings) = match go(val, &args[0], env, already_bound)? {
|
||||
Some(bindings) => (true, bindings),
|
||||
None => (false, already_bound.clone())
|
||||
};
|
||||
|
||||
for pred in preds {
|
||||
match pred {
|
||||
// "Num" predicate matches successfully if something is a number
|
||||
Value::Identifier(i) if i.as_ref() == "Num" => {
|
||||
match val {
|
||||
Value::Num(_) => (),
|
||||
_ => success = false
|
||||
}
|
||||
},
|
||||
// "Ident" does the same for idents
|
||||
Value::Identifier(i) if i.as_ref() == "Ident" => {
|
||||
match val {
|
||||
Value::Identifier(_) => (),
|
||||
_ => success = false
|
||||
}
|
||||
},
|
||||
// Invert match success
|
||||
Value::Identifier(i) if i.as_ref() == "Not" => {
|
||||
success = !success
|
||||
},
|
||||
Value::Call(head, args) if head.as_ref() == "And" => {
|
||||
// Try all patterns it's given, and if any fails then fail the match
|
||||
for arg in args.iter() {
|
||||
match go(val, arg, env, &bindings)? {
|
||||
Some(new_bindings) => bindings.extend(new_bindings),
|
||||
None => success = false
|
||||
}
|
||||
}
|
||||
},
|
||||
Value::Call(head, args) if head.as_ref() == "Eq" => {
|
||||
// Evaluate all arguments and check if they are equal
|
||||
let mut compare_against = None;
|
||||
for arg in args.iter() {
|
||||
let mut evaluated_value = arg.subst(&bindings);
|
||||
run_rewrite(&mut evaluated_value, env).context("evaluating Eq predicate")?;
|
||||
match compare_against {
|
||||
Some(ref x) => if x != &evaluated_value {
|
||||
success = false
|
||||
},
|
||||
None => compare_against = Some(evaluated_value)
|
||||
}
|
||||
}
|
||||
},
|
||||
Value::Call(head, args) if head.as_ref() == "Gte" => {
|
||||
// Evaluate all arguments and do comparison.
|
||||
let mut x = args[0].subst(&bindings);
|
||||
let mut y = args[1].subst(&bindings);
|
||||
run_rewrite(&mut x, env).context("evaluating Gte predicate")?;
|
||||
run_rewrite(&mut y, env).context("evaluating Gte predicate")?;
|
||||
success &= x >= y;
|
||||
},
|
||||
Value::Call(head, args) if head.as_ref() == "Or" => {
|
||||
// Tries all patterns it's given and will set the match to successful if *any* of them works
|
||||
for arg in args.iter() {
|
||||
match go(val, arg, env, &bindings)? {
|
||||
Some(new_bindings) => {
|
||||
bindings.extend(new_bindings);
|
||||
success = true
|
||||
},
|
||||
None => ()
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => bail!("invalid predicate {:?}", pred)
|
||||
}
|
||||
}
|
||||
Ok(match success {
|
||||
true => Some(bindings),
|
||||
false => None
|
||||
})
|
||||
},
|
||||
(Value::Call(exp_head, exp_args), Value::Call(rule_head, rule_args)) => {
|
||||
let mut exp_args = Cow::Borrowed(exp_args);
|
||||
// Regardless of any special casing for associativity etc., different heads mean rules can never match
|
||||
if exp_head != rule_head { return Ok(None) }
|
||||
|
||||
let op = env.get_op(exp_head);
|
||||
|
||||
// Copy bindings from the upper-level matching, so that things like "a+(b+a)" work.
|
||||
let mut out_bindings = already_bound.clone();
|
||||
|
||||
// Special case for associative expressions: split off extra arguments into a new tree
|
||||
if op.associative && rule_args.len() < exp_args.len() {
|
||||
let exp_args = exp_args.to_mut();
|
||||
let rest = exp_args.split_off(rule_args.len() - 1);
|
||||
let rem = Value::Call(exp_head.clone(), rest);
|
||||
exp_args.push(rem);
|
||||
}
|
||||
if rule_args.len() != exp_args.len() { return Ok(None) }
|
||||
|
||||
// Try and match all "adjacent" arguments to each other
|
||||
for (rule_arg, exp_arg) in rule_args.iter().zip(&*exp_args) {
|
||||
match go(exp_arg, rule_arg, env, &out_bindings)? {
|
||||
Some(x) => out_bindings.extend(x),
|
||||
None => return Ok(None)
|
||||
}
|
||||
}
|
||||
Ok(Some(out_bindings))
|
||||
},
|
||||
// identifier pattern matches anything, unless the identifier has already been bound to something else
|
||||
(x, Value::Identifier(a)) => {
|
||||
if let Some(b) = already_bound.get(a) {
|
||||
if b != x {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
Ok(Some(vec![(a.clone(), x.clone())].into_iter().collect()))
|
||||
},
|
||||
// anything else doesn't match
|
||||
_ => Ok(None)
|
||||
}
|
||||
}
|
||||
// special case at top level of expression - try to match pattern to different subranges of input
|
||||
match (expr, &rule.condition) {
|
||||
// this only applies to matching one "call" against a "call" pattern (with same head)
|
||||
(Value::Call(ehead, eargs), Value::Call(rhead, rargs)) => {
|
||||
// and also only to associative operations
|
||||
if env.get_op(ehead).associative && eargs.len() > rargs.len() && ehead == rhead {
|
||||
// consider all possible subranges of the arguments of the appropriate length
|
||||
for range_start in 0..=(eargs.len() - rargs.len()) {
|
||||
// extract the arguments & convert into new Value
|
||||
let c_args = eargs[range_start..range_start + rargs.len()].iter().cloned().collect();
|
||||
let c_call = Value::Call(ehead.clone(), c_args);
|
||||
// attempt to match the new subrange against the current rule
|
||||
if let Some(r) = match_and_bind(&c_call, rule, env)? {
|
||||
// generate new output with result
|
||||
let mut new_args = Vec::with_capacity(3);
|
||||
// add back extra start items
|
||||
if range_start != 0 {
|
||||
new_args.push(Value::Call(ehead.clone(), eargs[0..range_start].iter().cloned().collect()))
|
||||
}
|
||||
new_args.push(r);
|
||||
// add back extra end items
|
||||
if range_start + rargs.len() != eargs.len() {
|
||||
new_args.push(Value::Call(ehead.clone(), eargs[range_start + rargs.len()..eargs.len()].iter().cloned().collect()))
|
||||
}
|
||||
let new_exp = Value::Call(ehead.clone(), new_args);
|
||||
return Ok(Some(new_exp))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => ()
|
||||
}
|
||||
// substitute bindings from matching into either an intrinsic or the output of the rule
|
||||
if let Some(bindings) = go(expr, &rule.condition, env, &HashMap::new())? {
|
||||
Ok(Some(match &rule.result {
|
||||
RuleResult::Intrinsic(id) => env.intrinsics.get(id).unwrap()(&bindings).with_context(|| format!("applying intrinsic {}", id))?,
|
||||
RuleResult::Exp(e) => e.subst(&bindings)
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
// Sort any commutative expressions
|
||||
fn canonical_sort(v: &mut Value, env: &Env) -> Result<()> {
|
||||
match v {
|
||||
Value::Call(head, args) => if env.get_op(head).commutative {
|
||||
args.sort();
|
||||
},
|
||||
_ => ()
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Associative expression flattening.
|
||||
fn flatten_tree(v: &mut Value, env: &Env) -> Result<()> {
|
||||
match v {
|
||||
Value::Call(head, args) => {
|
||||
if env.get_op(head).associative {
|
||||
// Not the most efficient algorithm, but does work.
|
||||
// Repeatedly find the position of a flatten-able child node, and splice it into the argument list.
|
||||
loop {
|
||||
let mut move_pos = None;
|
||||
for (i, child) in args.iter().enumerate() {
|
||||
if let Some(child_head) = child.head() {
|
||||
if *head == child_head {
|
||||
move_pos = Some(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
match move_pos {
|
||||
Some(pos) => {
|
||||
let removed = std::mem::replace(&mut args[pos], Value::Num(0));
|
||||
// We know that removed will be a Call (because its head wasn't None earlier). Unfortunately, rustc does not know this.
|
||||
match removed {
|
||||
Value::Call(_, removed_child_args) => args.splice(pos..=pos, removed_child_args.into_iter()),
|
||||
_ => unreachable!()
|
||||
};
|
||||
},
|
||||
None => break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Also do sorting after flattening, to avoid any weirdness with ordering.
|
||||
canonical_sort(v, env)?;
|
||||
return Ok(())
|
||||
},
|
||||
_ => return Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// Applies rewrite rulesets to an expression.
|
||||
fn run_rewrite(v: &mut Value, env: &Env) -> Result<()> {
|
||||
loop {
|
||||
// Compare original and final hash instead of storing a copy of the original value and checking equality
|
||||
// Collision probability is negligible and this is substantially faster than storing/comparing copies.
|
||||
let original_hash = v.get_hash();
|
||||
|
||||
flatten_tree(v, env).context("flattening tree")?;
|
||||
// Call expressions can be rewritten using pattern matching rules; identifiers can be substituted for bindings if available
|
||||
match v {
|
||||
Value::Call(head, args) => {
|
||||
let head = head.clone();
|
||||
|
||||
// Rewrite sub-expressions using existing environment
|
||||
args.par_iter_mut().try_for_each(|arg| run_rewrite(arg, env).with_context(|| format!("rewriting {}", arg.render_to_string(env))))?;
|
||||
|
||||
// Try to apply all applicable rules from all rulesets, in sequence
|
||||
for ruleset in env.ruleset.iter() {
|
||||
if let Some(rules) = ruleset.get(&head) {
|
||||
// Within a ruleset, rules are applied backward. This is nicer for users using the program interactively.
|
||||
for rule in rules.iter().rev() {
|
||||
if let Some(result) = match_and_bind(v, rule, env).with_context(|| format!("applying rule {} -> {:?}", rule.condition.render_to_string(env), rule.result))? {
|
||||
*v = result;
|
||||
flatten_tree(v, env).context("flattening tree after rule application")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
// Substitute in bindings which have been provided
|
||||
Value::Identifier(ident) => {
|
||||
match env.bindings.get(ident) {
|
||||
Some(val) => {
|
||||
*v = val.clone();
|
||||
},
|
||||
None => return Ok(())
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
return Ok(())
|
||||
}
|
||||
}
|
||||
if original_hash == v.get_hash() {
|
||||
break
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Utility function for defining intrinsic functions for binary operators.
|
||||
// Converts a function which does the actual operation to a function from bindings to a value.
|
||||
fn wrap_binop<F: 'static + Fn(i128, i128) -> Result<i128> + Sync + Send>(op: F) -> Box<dyn Fn(&Bindings) -> Result<Value> + Sync + Send> {
|
||||
Box::new(move |bindings: &Bindings| {
|
||||
let a = bindings.get(&InlinableString::from("a")).context("binop missing first argument")?.assert_num("binop first argument")?;
|
||||
let b = bindings.get(&InlinableString::from("b")).context("binop missing second argument")?.assert_num("binop second argument")?;
|
||||
op(a, b).map(Value::Num)
|
||||
})
|
||||
}
|
||||
|
||||
// Provides a basic environment with operator commutativity/associativity operations and intrinsics.
|
||||
fn make_initial_env() -> Env {
|
||||
let mut ops = HashMap::new();
|
||||
ops.insert(InlinableString::from("+"), Operation { commutative: true, associative: true });
|
||||
ops.insert(InlinableString::from("*"), Operation { commutative: true, associative: true });
|
||||
ops.insert(InlinableString::from("-"), Operation { commutative: false, associative: false });
|
||||
ops.insert(InlinableString::from("/"), Operation { commutative: false, associative: false });
|
||||
ops.insert(InlinableString::from("^"), Operation { commutative: false, associative: false });
|
||||
ops.insert(InlinableString::from("="), Operation { commutative: false, associative: false });
|
||||
ops.insert(InlinableString::from("#"), Operation { commutative: false, associative: true });
|
||||
let ops = Arc::new(ops);
|
||||
let mut intrinsics = HashMap::new();
|
||||
intrinsics.insert(0, wrap_binop(|a, b| a.checked_add(b).context("integer overflow")));
|
||||
intrinsics.insert(1, wrap_binop(|a, b| a.checked_sub(b).context("integer overflow")));
|
||||
intrinsics.insert(2, wrap_binop(|a, b| a.checked_mul(b).context("integer overflow")));
|
||||
intrinsics.insert(3, wrap_binop(|a, b| a.checked_div(b).context("division by zero")));
|
||||
intrinsics.insert(4, wrap_binop(|a, b| {
|
||||
// The "pow" function takes a usize (machine-sized unsigned integer) and an i128 may not fit into this, so an extra conversion is needed
|
||||
Ok(a.pow(b.try_into()?))
|
||||
}));
|
||||
intrinsics.insert(5, Box::new(|bindings| {
|
||||
// Substitute a single, given binding var=value into a target expression
|
||||
let var = bindings.get(&InlinableString::from("var")).unwrap();
|
||||
let value = bindings.get(&InlinableString::from("value")).unwrap();
|
||||
let target = bindings.get(&InlinableString::from("target")).unwrap();
|
||||
let name = var.assert_ident("Subst")?;
|
||||
let mut new_bindings = HashMap::new();
|
||||
new_bindings.insert(name, value.clone());
|
||||
Ok(target.subst(&new_bindings))
|
||||
}));
|
||||
intrinsics.insert(6, wrap_binop(|a, b| a.checked_rem(b).context("division by zero")));
|
||||
let intrinsics = Arc::new(intrinsics);
|
||||
Env {
|
||||
ruleset: vec![],
|
||||
ops: ops.clone(),
|
||||
intrinsics: intrinsics.clone(),
|
||||
bindings: HashMap::new()
|
||||
}
|
||||
}
|
||||
|
||||
const BUILTINS: &str = "
|
||||
SetStage[all]
|
||||
a#Num + b#Num = Intrinsic[0]
|
||||
a#Num - b#Num = Intrinsic[1]
|
||||
a#Num * b#Num = Intrinsic[2]
|
||||
a#Num / b#Num = Intrinsic[3]
|
||||
a#Num ^ b#Num = Intrinsic[4]
|
||||
Subst[var=value, target] = Intrinsic[5]
|
||||
Mod[a#Num, b#Num] = Intrinsic[6]
|
||||
PushRuleset[builtins]
|
||||
";
|
||||
|
||||
const GENERAL_RULES: &str = "
|
||||
SetStage[all]
|
||||
(a*b#Num)+(a*c#Num) = (b+c)*a
|
||||
Negate[a] = 0 - a
|
||||
a^b*a^c = a^(b+c)
|
||||
a^0 = 1
|
||||
a^1 = a
|
||||
(a^b)^c = a^(b*c)
|
||||
0*a = 0
|
||||
0+a = a
|
||||
1*a = a
|
||||
x/x = 1
|
||||
(n*x)/x = n
|
||||
PushRuleset[general_rules]
|
||||
";
|
||||
|
||||
const NORMALIZATION_RULES: &str = "
|
||||
SetStage[norm]
|
||||
a/b = a*b^Negate[1]
|
||||
a+b#Num*a = (b+1)*a
|
||||
a^b#Num#Gte[b, 2] = a*a^(b-1)
|
||||
a-c#Num*b = a+Negate[c]*b
|
||||
a+a = 2*a
|
||||
a*(b+c) = a*b+a*c
|
||||
a-b = a+Negate[1]*b
|
||||
PushRuleset[normalization]
|
||||
";
|
||||
|
||||
const DENORMALIZATION_RULES: &str = "
|
||||
SetStage[denorm]
|
||||
a*a = a^2
|
||||
a^b#Num*a = a^(b+1)
|
||||
c+a*b#Num#Gte[0, b] = c-a*Negate[b]
|
||||
PushRuleset[denormalization]
|
||||
";
|
||||
|
||||
const DIFFERENTIATION_DEFINITION: &str = "
|
||||
SetStage[all]
|
||||
D[x, x] = 1
|
||||
D[a#Num, x] = 0
|
||||
D[f+g, x] = D[f, x] + D[g, x]
|
||||
D[f*g, x] = D[f, x] * g + D[g, x] * f
|
||||
D[a#Num*f, x] = a * D[f, x]
|
||||
PushRuleset[differentiation]
|
||||
";
|
||||
|
||||
const FACTOR_DEFINITION: &str = "
|
||||
SetStage[post_norm]
|
||||
Factor[x, a*x+b] = x * (a + Factor[x, b] / x)
|
||||
PushRuleset[factor]
|
||||
SetStage[pre_denorm]
|
||||
Factor[x, a] = a
|
||||
PushRuleset[factor_postprocess]
|
||||
SetStage[denorm]
|
||||
x^n/x = x^(n-1)
|
||||
(a*x^n)/x = a*x^(n-1)
|
||||
PushRuleset[factor_postpostprocess]
|
||||
";
|
||||
|
||||
struct ImperativeCtx {
|
||||
bindings: Bindings,
|
||||
current_ruleset_stage: InlinableString,
|
||||
current_ruleset: Ruleset,
|
||||
rulesets: HashMap<InlinableString, Arc<Ruleset>>,
|
||||
stages: Vec<(InlinableString, Vec<InlinableString>)>,
|
||||
base_env: Env
|
||||
}
|
||||
|
||||
impl ImperativeCtx {
|
||||
// Make a new imperative context
|
||||
// Stages are currently hardcoded, as adding a way to manage them would add lots of complexity
|
||||
// for limited benefit
|
||||
fn init() -> Self {
|
||||
let stages = [
|
||||
"pre_norm",
|
||||
"norm",
|
||||
"post_norm",
|
||||
"pre_denorm",
|
||||
"denorm",
|
||||
"post_denorm"
|
||||
].iter().map(|name| (InlinableString::from(*name), vec![])).collect();
|
||||
ImperativeCtx {
|
||||
bindings: HashMap::new(),
|
||||
current_ruleset_stage: InlinableString::from("post_norm"),
|
||||
current_ruleset: HashMap::new(),
|
||||
rulesets: HashMap::new(),
|
||||
stages,
|
||||
base_env: make_initial_env()
|
||||
}
|
||||
}
|
||||
|
||||
// Insert a rule into the current ruleset; handles switching out the result for a relevant intrinsic use, generating possible reorderings, and inserting into the lookup map.
|
||||
fn insert_rule(&mut self, condition: &Value, result_val: Value) -> Result<()> {
|
||||
let result = match result_val {
|
||||
Value::Call(head, args) if head == "Intrinsic" => RuleResult::Intrinsic(args[0].assert_num("Intrinsic ID")? as usize),
|
||||
_ => RuleResult::Exp(result_val)
|
||||
};
|
||||
for rearrangement in condition.pattern_reorderings(&self.base_env).into_iter() {
|
||||
let rule = Rule {
|
||||
condition: rearrangement,
|
||||
result: result.clone()
|
||||
};
|
||||
self.current_ruleset.entry(condition.head().unwrap()).or_insert_with(Vec::new).push(rule);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Run a single statement (roughly, a line of user input) on the current context
|
||||
fn eval_statement(&mut self, mut stmt: Value) -> Result<Option<Value>> {
|
||||
match stmt {
|
||||
// = sets a binding or generates a new rule.
|
||||
Value::Call(head, args) if head.as_ref() == "=" => {
|
||||
match &args[0] {
|
||||
// Create a binding if the LHS (left hand side) is just an identifier
|
||||
Value::Identifier(id) => {
|
||||
let rhs = self.eval_statement(args[1].clone())?;
|
||||
if let Some(val) = rhs.clone() {
|
||||
self.bindings.insert(id.clone(), val);
|
||||
}
|
||||
Ok(rhs)
|
||||
},
|
||||
// If the LHS is a call, then a rule should be created instead.
|
||||
Value::Call(_head, _args) => {
|
||||
let rhs = self.eval_statement(args[1].clone())?;
|
||||
if let Some(val) = rhs.clone() {
|
||||
self.insert_rule(&args[0], val)?;
|
||||
}
|
||||
Ok(rhs)
|
||||
},
|
||||
// Rebinding numbers can only bring confusion, so it is not allowed.
|
||||
// They also do not have a head, and so cannot be inserted into the ruleset anyway.
|
||||
Value::Num(_) => bail!("You cannot rebind numbers")
|
||||
}
|
||||
},
|
||||
// SetStage[] calls set the stage the current ruleset will be applied at
|
||||
Value::Call(head, args) if head.as_ref() == "SetStage" => {
|
||||
let stage = args[0].assert_ident("SetStage requires an identifier for stage")?;
|
||||
if stage != "all" && None == self.stages.iter().position(|s| s.0 == stage) {
|
||||
bail!("No such stage {}", stage);
|
||||
}
|
||||
self.current_ruleset_stage = stage;
|
||||
Ok(None)
|
||||
},
|
||||
// Move the current ruleset from the "buffer" into the actual list of rules to be applied at each stage
|
||||
Value::Call(head, args) if head.as_ref() == "PushRuleset" => {
|
||||
let name = args[0].assert_ident("PushRuleset requires an identifier for ruleset name")?;
|
||||
// Get ruleset and set the current one to empty
|
||||
let ruleset = std::mem::replace(&mut self.current_ruleset, HashMap::new());
|
||||
// Push ruleset to stages it specifies
|
||||
for (stage_name, stage_rulesets) in self.stages.iter_mut() {
|
||||
if *stage_name == self.current_ruleset_stage || self.current_ruleset_stage == "all" {
|
||||
stage_rulesets.push(name.clone());
|
||||
}
|
||||
}
|
||||
// Insert actual ruleset data under its name
|
||||
self.rulesets.insert(name, Arc::new(ruleset));
|
||||
Ok(None)
|
||||
},
|
||||
// Anything not special should just be repeatedly run through each rewrite stage.
|
||||
_ => {
|
||||
let env = self.base_env.with_bindings(&self.bindings);
|
||||
for (stage_name, stage_rulesets) in self.stages.iter() {
|
||||
// Add relevant rulesets to a new environment for this stage
|
||||
let mut env = env.clone();
|
||||
for ruleset in stage_rulesets.iter() {
|
||||
env = env.with_ruleset(self.rulesets[ruleset].clone());
|
||||
}
|
||||
// Also add the current ruleset if applicable
|
||||
if self.current_ruleset_stage == *stage_name || self.current_ruleset_stage == "all" {
|
||||
env = env.with_ruleset(Arc::new(self.current_ruleset.clone()));
|
||||
}
|
||||
run_rewrite(&mut stmt, &env).with_context(|| format!("failed in {} stage", stage_name))?;
|
||||
// If a ruleset is only meant to be applied in one particular stage, it shouldn't have any later stages applied to it,
|
||||
// or the transformation it's meant to do may be undone
|
||||
if self.current_ruleset_stage == *stage_name {
|
||||
break
|
||||
}
|
||||
}
|
||||
Ok(Some(stmt))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate an entire "program" (multiple statements delineated by ; or newlines)
|
||||
fn eval_program(&mut self, program: &str) -> Result<Option<Value>> {
|
||||
let mut tokens = parse::lex(program)?;
|
||||
let mut last_value = None;
|
||||
loop {
|
||||
// Split at the next break token
|
||||
let remaining_tokens = tokens.iter().position(|x| *x == parse::Token::Break).map(|ix| tokens.split_off(ix + 1));
|
||||
// Trim EOF/break tokens
|
||||
match tokens[tokens.len() - 1] {
|
||||
parse::Token::Break | parse::Token::EOF => tokens.truncate(tokens.len() - 1),
|
||||
_ => ()
|
||||
};
|
||||
// If the statement/line isn't blank, readd EOF for the parser, parse into an AST then Value, and evaluate the statement
|
||||
if tokens.len() > 0 {
|
||||
tokens.push(parse::Token::EOF);
|
||||
let value = Value::from_ast(parse::parse(tokens)?);
|
||||
last_value = self.eval_statement(value)?;
|
||||
}
|
||||
// If there was no break after the current position, this is now done. Otherwise, move onto the new remaining tokens.
|
||||
match remaining_tokens {
|
||||
Some(t) => { tokens = t },
|
||||
None => break
|
||||
}
|
||||
}
|
||||
Ok(last_value)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let mut ctx = ImperativeCtx::init();
|
||||
ctx.eval_program(BUILTINS)?;
|
||||
ctx.eval_program(GENERAL_RULES)?;
|
||||
ctx.eval_program(FACTOR_DEFINITION)?;
|
||||
ctx.eval_program(DENORMALIZATION_RULES)?;
|
||||
ctx.eval_program(NORMALIZATION_RULES)?;
|
||||
ctx.eval_program(DIFFERENTIATION_DEFINITION)?;
|
||||
let stdin = std::io::stdin();
|
||||
for line in stdin.lock().lines() {
|
||||
let line = line?;
|
||||
match ctx.eval_program(&line) {
|
||||
Ok(Some(result)) => println!("{}", result.render_to_string(&ctx.base_env)),
|
||||
Ok(None) => (),
|
||||
Err(e) => println!("Error: {:?}", e)
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::{ImperativeCtx, BUILTINS, GENERAL_RULES, NORMALIZATION_RULES, DENORMALIZATION_RULES, DIFFERENTIATION_DEFINITION, FACTOR_DEFINITION};
|
||||
|
||||
#[test]
|
||||
fn end_to_end_tests() {
|
||||
let mut ctx = ImperativeCtx::init();
|
||||
ctx.eval_program(BUILTINS).unwrap();
|
||||
ctx.eval_program(GENERAL_RULES).unwrap();
|
||||
ctx.eval_program(FACTOR_DEFINITION).unwrap();
|
||||
ctx.eval_program(DENORMALIZATION_RULES).unwrap();
|
||||
ctx.eval_program(NORMALIZATION_RULES).unwrap();
|
||||
ctx.eval_program(DIFFERENTIATION_DEFINITION).unwrap();
|
||||
let test_cases = [
|
||||
("Factor[x, x*3+x^2]", "(3+x)*x"),
|
||||
("x^a/x^(a+1)", "x^Negate[1]"),
|
||||
("Negate[a+b]", "Negate[1]*b-a"),
|
||||
("Subst[x=4, x+4+4+4+4]", "20"),
|
||||
("(a+b)*(c+d)*(e+f)", "a*c*e+a*c*f+a*d*e+a*d*f+b*c*e+b*c*f+b*d*e+b*d*f"),
|
||||
("(12+55)^3-75+16/(2*2)+5+3*4", "300709"),
|
||||
("D[3*x^3 + 6*x, x] ", "6+9*x^2"),
|
||||
("Fib[n] = Fib[n-1] + Fib[n-2]
|
||||
Fib[0] = 0
|
||||
Fib[1] = 1
|
||||
Fib[6]", "8"),
|
||||
("Subst[b=a, b+a]", "2*a"),
|
||||
("a = 7
|
||||
b = Negate[4]
|
||||
a + b", "3"),
|
||||
("IsEven[x] = 0
|
||||
IsEven[x#Eq[Mod[x, 2], 0]] = 1
|
||||
IsEven[3] - IsEven[4]", "Negate[1]"),
|
||||
("(a+b+c)^2", "2*a*b+2*a*c+2*b*c+a^2+b^2+c^2"),
|
||||
("(x+2)^7", "128+2*x^6+12*x^5+12*x^6+16*x^3+16*x^5+24*x^4+24*x^5+32*x^2+32*x^3+32*x^5+128*x^2+256*x^4+448*x+512*x^2+512*x^3+x^7")
|
||||
];
|
||||
for (input, expected_result) in test_cases {
|
||||
let lhs = ctx.eval_program(input).unwrap();
|
||||
let lhs = lhs.as_ref().unwrap().render_to_string(&ctx.base_env);
|
||||
println!("{} evaluated to {}; expected {}", input, lhs, expected_result);
|
||||
assert_eq!(lhs, expected_result);
|
||||
}
|
||||
|
||||
let error_cases = [
|
||||
("1/0")
|
||||
];
|
||||
|
||||
for error_case in error_cases {
|
||||
if let Err(e) = ctx.eval_program(error_case) {
|
||||
println!("{} produced error {:?}", error_case, e);
|
||||
} else {
|
||||
panic!("should have errored: {}", error_case)
|
||||
}
|
||||
}
|
||||
|
||||
println!("All tests passed.")
|
||||
}
|
||||
}
|
215
src/parse.rs
Normal file
215
src/parse.rs
Normal file
@ -0,0 +1,215 @@
|
||||
use std::str::FromStr;
|
||||
use anyhow::{Result, anyhow, Context};
|
||||
use std::fmt;
|
||||
use inlinable_string::{InlinableString, StringExt};
|
||||
|
||||
use crate::util::char_to_string;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub enum Token { Number(InlinableString), OpenBracket, CloseBracket, Op(char), EOF, Identifier(InlinableString), OpenSqBracket, CloseSqBracket, Comma, Break }
|
||||
#[derive(Debug)]
|
||||
enum LexState {
|
||||
Number(InlinableString),
|
||||
Identifier(InlinableString),
|
||||
None
|
||||
}
|
||||
|
||||
// lexer
|
||||
// converts input InlinableString to tokens
|
||||
pub fn lex(input: &str) -> Result<Vec<Token>> {
|
||||
let mut toks = vec![];
|
||||
let mut state = LexState::None;
|
||||
for (index, char) in input.chars().enumerate() {
|
||||
state = match (char, state) {
|
||||
// if digit seen, switch into number state (and commit existing one if relevant)
|
||||
('0'..='9', LexState::None) => LexState::Number(char_to_string(char)),
|
||||
('0'..='9', LexState::Identifier(s)) => {
|
||||
toks.push(Token::Identifier(s));
|
||||
LexState::Number(char_to_string(char))
|
||||
},
|
||||
('0'..='9', LexState::Number(mut n)) => {
|
||||
n.push(char);
|
||||
LexState::Number(n)
|
||||
},
|
||||
// if special character seen, commit existing state and push operator/bracket/comma token
|
||||
('#' | '+' | '(' | ')' | '-' | '/' | '*' | '^' | '=' | '[' | ']' | ',', state) => {
|
||||
match state {
|
||||
LexState::Number(s) => toks.push(Token::Number(s)),
|
||||
LexState::Identifier(s) => toks.push(Token::Identifier(s)),
|
||||
_ => ()
|
||||
};
|
||||
toks.push(match char {
|
||||
'(' => Token::OpenBracket, ')' => Token::CloseBracket,
|
||||
'[' => Token::OpenSqBracket, ']' => Token::CloseSqBracket,
|
||||
',' => Token::Comma,
|
||||
a => Token::Op(a)
|
||||
});
|
||||
LexState::None
|
||||
},
|
||||
// semicolon or newline is break
|
||||
(';' | '\n', state) => {
|
||||
match state {
|
||||
LexState::Number(s) => toks.push(Token::Number(s)),
|
||||
LexState::Identifier(s) => toks.push(Token::Identifier(s)),
|
||||
_ => ()
|
||||
};
|
||||
toks.push(Token::Break);
|
||||
LexState::None
|
||||
},
|
||||
// ignore all whitespace
|
||||
(' ', state) => state,
|
||||
// treat all unknown characters as part of identifiers
|
||||
(char, LexState::None) => { LexState::Identifier(char_to_string(char)) },
|
||||
(char, LexState::Identifier(mut s)) => {
|
||||
s.push(char);
|
||||
LexState::Identifier(s)
|
||||
}
|
||||
(char, state) => return Err(anyhow!("got {} in state {:?} (char {})", char, state, index))
|
||||
}
|
||||
}
|
||||
// commit last thing
|
||||
match state {
|
||||
LexState::Number(s) => toks.push(Token::Number(s)),
|
||||
LexState::Identifier(s) => toks.push(Token::Identifier(s)),
|
||||
_ => ()
|
||||
};
|
||||
toks.push(Token::EOF);
|
||||
Ok(toks)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ParseError {
|
||||
Invalid(Token)
|
||||
}
|
||||
|
||||
impl fmt::Display for ParseError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ParseError::Invalid(tok) => write!(f, "invalid token {:?}", tok)
|
||||
}
|
||||
}
|
||||
}
|
||||
impl std::error::Error for ParseError {}
|
||||
|
||||
// parser: contains the sequence of tokens being parsed and current position (index in that)
|
||||
// no lookahead is supported/used
|
||||
struct Parser {
|
||||
tokens: Vec<Token>,
|
||||
position: usize
|
||||
}
|
||||
|
||||
// static table of precedence for each operator
|
||||
pub fn precedence(c: char) -> usize {
|
||||
match c {
|
||||
'=' => 0,
|
||||
'+' => 1,
|
||||
'*' => 2,
|
||||
'-' => 1,
|
||||
'/' => 2,
|
||||
'^' => 3,
|
||||
'#' => 4,
|
||||
c => panic!("invalid operator char {}", c)
|
||||
}
|
||||
}
|
||||
|
||||
fn is_left_associative(c: char) -> bool {
|
||||
match c {
|
||||
'^' => false,
|
||||
_ => true
|
||||
}
|
||||
}
|
||||
|
||||
impl Parser {
|
||||
// Parsing utility functions
|
||||
// Get current token
|
||||
fn current(&self) -> Token { self.tokens[self.position].clone() }
|
||||
// Advance current token
|
||||
fn advance(&mut self) { self.position += 1 }
|
||||
|
||||
// Match current token against predicate and propagate error if it is not matched
|
||||
fn expect<T, F: Fn(&Token) -> Result<T, ()>>(&mut self, pred: F) -> Result<T, ParseError> {
|
||||
let current = self.current();
|
||||
match pred(¤t) {
|
||||
Ok(r) => {
|
||||
self.advance();
|
||||
Ok(r)
|
||||
},
|
||||
Err(()) => Err(ParseError::Invalid(current))
|
||||
}
|
||||
}
|
||||
|
||||
// Parse "leaf" expression: number, arbitrary expression in brackets, identifier, or call (fn[arg1, arg2])
|
||||
fn parse_leaf(&mut self) -> Result<Ast, ParseError> {
|
||||
let res = match self.current() {
|
||||
Token::OpenBracket => {
|
||||
self.advance();
|
||||
let res = self.parse_expr(0)?;
|
||||
self.expect(|t| if *t == Token::CloseBracket { Ok(()) } else { Err(()) })?;
|
||||
res
|
||||
},
|
||||
Token::Number(n) => {
|
||||
let x = i128::from_str(&n).unwrap();
|
||||
self.advance();
|
||||
Ast::Num(x)
|
||||
},
|
||||
Token::Identifier(s) => {
|
||||
self.advance();
|
||||
Ast::Identifier(s)
|
||||
}
|
||||
t => return Err(ParseError::Invalid(t))
|
||||
};
|
||||
// Detection of call syntax: if [ occurs, try and parse comma-separated arguments until a ]
|
||||
if let Token::OpenSqBracket = self.current() {
|
||||
self.advance();
|
||||
let mut args = vec![];
|
||||
loop {
|
||||
args.push(self.parse_expr(0)?);
|
||||
match self.expect(|t| if *t == Token::CloseSqBracket || *t == Token::Comma { Ok(t.clone()) } else { Err(()) })? {
|
||||
Token::CloseSqBracket => break,
|
||||
_ => ()
|
||||
}
|
||||
}
|
||||
Ok(Ast::Call(Box::new(res), args))
|
||||
} else {
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse expression including operators, using precedence climbing
|
||||
fn parse_expr(&mut self, prec: usize) -> Result<Ast, ParseError> {
|
||||
let mut result = self.parse_leaf()?;
|
||||
loop {
|
||||
match self.current() {
|
||||
Token::Op(op) if precedence(op) >= prec => {
|
||||
self.advance();
|
||||
let nested_precedence = if is_left_associative(op) { precedence(op) + 1 } else { precedence(op) };
|
||||
let next = self.parse_expr(nested_precedence)?;
|
||||
result = Ast::Op(op, Box::new(result), Box::new(next))
|
||||
},
|
||||
// Break out of loop if operator's precedence is lower, or nonoperator thing encountered
|
||||
// This means that a lower-precedence operator will become part of the tree enclosing the current one
|
||||
_ => break
|
||||
};
|
||||
};
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse(t: Vec<Token>) -> Result<Ast> {
|
||||
let mut parser = Parser {
|
||||
tokens: t,
|
||||
position: 0
|
||||
};
|
||||
// Provide slightly more helpful error message indicating the token which isn't valid
|
||||
let result = parser.parse_expr(0).with_context(|| format!("at token {}", parser.position))?;
|
||||
if parser.current() != Token::EOF { return Err(anyhow!("Expected EOF at end of token sequence")) }
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Ast {
|
||||
Num(i128),
|
||||
Identifier(InlinableString),
|
||||
Op(char, Box<Ast>, Box<Ast>),
|
||||
Call(Box<Ast>, Vec<Ast>)
|
||||
}
|
7
src/util.rs
Normal file
7
src/util.rs
Normal file
@ -0,0 +1,7 @@
|
||||
use inlinable_string::{InlinableString, StringExt};
|
||||
|
||||
pub fn char_to_string(c: char) -> InlinableString {
|
||||
let mut s = InlinableString::new();
|
||||
s.push(c);
|
||||
s
|
||||
}
|
217
src/value.rs
Normal file
217
src/value.rs
Normal file
@ -0,0 +1,217 @@
|
||||
use inlinable_string::{InlinableString, StringExt};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::collections::{hash_map::DefaultHasher, HashSet, HashMap};
|
||||
use std::fmt::{self, Write};
|
||||
use itertools::Itertools;
|
||||
use anyhow::{Result, anyhow};
|
||||
|
||||
use crate::parse::{Ast, precedence};
|
||||
use crate::env::{Env, Bindings};
|
||||
use crate::util::char_to_string;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash)]
|
||||
pub enum Value {
|
||||
Num(i128),
|
||||
Call(InlinableString, Vec<Value>),
|
||||
Identifier(InlinableString),
|
||||
}
|
||||
|
||||
impl Value {
|
||||
// Converts an AST from `parse.rs` to a Value
|
||||
pub fn from_ast(ast: Ast) -> Self {
|
||||
match ast {
|
||||
Ast::Op(char, t1, t2) => Value::Call(char_to_string(char), vec![Value::from_ast(*t1), Value::from_ast(*t2)]),
|
||||
Ast::Call(f, args) => {
|
||||
Value::Call(match *f {
|
||||
Ast::Identifier(n) => n,
|
||||
_ => unimplemented!()
|
||||
}, args.into_iter().map(Value::from_ast).collect())
|
||||
},
|
||||
Ast::Num(n) => Value::Num(n),
|
||||
Ast::Identifier(i) => Value::Identifier(i)
|
||||
}
|
||||
}
|
||||
|
||||
// Gets the hash of a Value
|
||||
pub fn get_hash(&self) -> u64 {
|
||||
// according to https://doc.rust-lang.org/std/collections/hash_map/struct.DefaultHasher.html, all instances created here are guaranteed to be the same
|
||||
let mut hasher = DefaultHasher::new();
|
||||
self.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
// Gets the head (string at start of call expression)
|
||||
pub fn head(&self) -> Option<InlinableString> {
|
||||
match self {
|
||||
Value::Call(fun, _) => Some(fun.clone()),
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
|
||||
// Replace variables with incremental IDs, vaguely like de Bruijn indices in lambda calculus.
|
||||
// Allows patterns to be compared regardless of the names of the identifiers they contain.
|
||||
fn canonicalize_variables(&self) -> (Self, Bindings) {
|
||||
fn go(v: &Value, bindings: &mut Bindings, counter: &mut usize) -> Value {
|
||||
match v {
|
||||
Value::Num(_) => v.clone(),
|
||||
Value::Identifier(name) => {
|
||||
match bindings.get(name) {
|
||||
Some(id) => id.clone(),
|
||||
None => {
|
||||
let mut next_id = InlinableString::new();
|
||||
write!(next_id, "{:X}", counter).unwrap();
|
||||
let next_id = Value::Identifier(next_id);
|
||||
*counter += 1;
|
||||
bindings.insert(name.clone(), next_id.clone());
|
||||
next_id
|
||||
}
|
||||
}
|
||||
},
|
||||
Value::Call(head, args) => {
|
||||
Value::Call(head.clone(), args.iter().map(|x| go(x, bindings, counter)).collect())
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut vars = HashMap::new();
|
||||
let mut ctr = 0;
|
||||
(go(self, &mut vars, &mut ctr), vars)
|
||||
}
|
||||
|
||||
// Hash the canonical-variables form of a pattern. Allows patterns to be checked for equality safely.
|
||||
fn pattern_hash(&self) -> u64 {
|
||||
self.canonicalize_variables().0.get_hash()
|
||||
}
|
||||
|
||||
// Generate all possible ways a pattern can be ordered, given commutative operators it may contain.
|
||||
// This also recurses into child nodes.
|
||||
pub fn pattern_reorderings(&self, env: &Env) -> Vec<Self> {
|
||||
// Filter out redundant patterns from a result, and convert the argument lists back into Values
|
||||
// Due to typing, this has to be factored out into a separate generic function
|
||||
// rather than being part of the main logic below.
|
||||
fn map_result<I: Iterator<Item=Vec<Value>>>(head: &InlinableString, result: I) -> Vec<Value> {
|
||||
let mut existing_patterns = HashSet::new();
|
||||
result.flat_map(|x| {
|
||||
let resulting_value = Value::Call(head.clone(), x);
|
||||
let hash = resulting_value.pattern_hash();
|
||||
if existing_patterns.contains(&hash) {
|
||||
None
|
||||
} else {
|
||||
existing_patterns.insert(hash);
|
||||
Some(resulting_value)
|
||||
}
|
||||
}).collect()
|
||||
}
|
||||
|
||||
match self {
|
||||
// Call expressions can have their child nodes reordered and can be reordered themselves, if the head is a commutative operator
|
||||
Value::Call(head, args) => {
|
||||
let result = args.iter()
|
||||
// Recursive step: generate all the valid reorderings of each child node.
|
||||
.map(|x| x.pattern_reorderings(env))
|
||||
// Generate all possible combinations of those reorderings.
|
||||
.multi_cartesian_product();
|
||||
// Generate all possible permutations of those combinations, if the operation allows for this.
|
||||
if env.get_op(head).commutative {
|
||||
map_result(head, result.flat_map(|comb| {
|
||||
let k = comb.len();
|
||||
comb.into_iter().permutations(k)
|
||||
}))
|
||||
} else {
|
||||
map_result(head, result)
|
||||
}
|
||||
},
|
||||
// Any other expression type is not reorderable.
|
||||
_ => {
|
||||
vec![self.clone()]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Substitute bindings into an expression tree;
|
||||
// the main match_and_bind function can also do this, but doing it here is more efficient
|
||||
// when its full power is not needed
|
||||
pub fn subst(&self, bindings: &Bindings) -> Value {
|
||||
match self {
|
||||
Value::Identifier(name) => {
|
||||
match bindings.get(name) {
|
||||
Some(value) => value.clone(),
|
||||
None => Value::Identifier(name.clone())
|
||||
}
|
||||
},
|
||||
Value::Call(fun, args) => Value::Call(fun.clone(), args.iter().map(|x| x.subst(bindings)).collect()),
|
||||
x => x.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that a value is a number, returning an error otherwise.
|
||||
pub fn assert_num(&self, ctx: &'static str) -> Result<i128> {
|
||||
match self {
|
||||
Value::Num(n) => Ok(*n),
|
||||
_ => Err(anyhow!("expected number, got {:?}", self).context(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// The same but for identfiers.
|
||||
pub fn assert_ident(&self, ctx: &'static str) -> Result<InlinableString> {
|
||||
match self {
|
||||
Value::Identifier(i) => Ok(i.clone()),
|
||||
_ => Err(anyhow!("expected identifier, got {:?}", self).context(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn render<W: fmt::Write>(&self, f: &mut W, env: &Env) -> fmt::Result {
|
||||
fn go<W: fmt::Write>(v: &Value, parent_prec: Option<usize>, env: &Env, f: &mut W) -> fmt::Result {
|
||||
match v {
|
||||
// As unary - isn't parsed, negative numbers are written with Negate instead.
|
||||
Value::Num(x) => if *x >= 0 {
|
||||
write!(f, "{}", x)
|
||||
} else { write!(f, "Negate[{}]", -x) },
|
||||
Value::Identifier(i) => write!(f, "{}", i),
|
||||
Value::Call(head, args) => {
|
||||
match env.ops.get(head) {
|
||||
Some(_) => {
|
||||
// If the precedence of the enclosing operation is greater than or equal to this one's,
|
||||
// add brackets around this one
|
||||
let this_prec = precedence(head.chars().next().unwrap());
|
||||
let render_brackets = match parent_prec {
|
||||
Some(prec) => prec >= this_prec,
|
||||
None => false
|
||||
};
|
||||
if render_brackets {
|
||||
write!(f, "(")?;
|
||||
}
|
||||
for (i, arg) in args.iter().enumerate() {
|
||||
go(arg, Some(this_prec), env, f)?;
|
||||
if i + 1 != args.len() {
|
||||
write!(f, "{}", head)?;
|
||||
}
|
||||
}
|
||||
if render_brackets {
|
||||
write!(f, ")")?;
|
||||
}
|
||||
},
|
||||
// Just write a call expression with square brackets.
|
||||
None => {
|
||||
write!(f, "{}[", head)?;
|
||||
for (i, arg) in args.iter().enumerate() {
|
||||
go(arg, None, env, f)?;
|
||||
if i + 1 != args.len() {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
go(self, None, env, f)
|
||||
}
|
||||
|
||||
pub fn render_to_string(&self, env: &Env) -> InlinableString {
|
||||
let mut out = InlinableString::new();
|
||||
self.render(&mut out, env).unwrap();
|
||||
out
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user