Fuzzy search, more generalized result types

This commit is contained in:
osmarks 2018-08-14 21:16:27 +01:00
parent fc59537468
commit 998e77fa69
3 changed files with 80 additions and 10 deletions

View File

@ -3,6 +3,7 @@
local w = require "lib" local w = require "lib"
local d = require "luadash" local d = require "luadash"
local fuzzy_match = require "fuzzy"
local conf = w.load_config({ local conf = w.load_config({
"buffer_internal", "buffer_internal",
@ -96,12 +97,9 @@ local function find_by_ID_meta_NBT(ID, meta, NBT_hash)
end end
local function search(query, threshold) local function search(query, threshold)
local threshold = threshold or 4
local results = find(function(item) local results = find(function(item)
local distance = d.distance(string.lower(query), string.lower(item.display_name)) local match, best_start = fuzzy_match(item.display_name, query)
if distance < threshold then if best_start ~= nil and match > 0 then return true, match end
return true, distance
else return false end
end) end)
return d.sort_by(results, function(x) return x.extra end) -- sort returned results by closeness to query return d.sort_by(results, function(x) return x.extra end) -- sort returned results by closeness to query
end end

66
fuzzy.lua Normal file
View File

@ -0,0 +1,66 @@
-- Squid's fuzzy search thing
-- https://github.com/SquidDev-CC/artist/blob/vnext/artist/lib/match.lua
local score_weight = 1000
local adjacency_bonus = 5
local leading_letter_penalty = -3
local leading_letter_penalty_max = -9
local unmatched_letter_penalty = -1
local function match_simple(str, ptrn)
local best_score, best_start = 0, nil
-- Trim the two strings
ptrn = ptrn:gsub("^ *", ""):gsub(" *$", "")
str = str:gsub("^ *", ""):gsub(" *$", "")
local str_lower = str:lower()
local ptrn_lower = ptrn:lower()
local start = 1
while true do
-- Find a location where the first character matches
start = str_lower:find(ptrn_lower:sub(1, 1), start, true)
if not start then break end
-- All letters before the current one are considered leading, so add them to our penalty
local score = score_weight + math.max(leading_letter_penalty * (start - 1), leading_letter_penalty_max)
local previous_match = true
-- We now walk through each pattern character and attempt to determine if they match
local str_pos, ptrn_pos = start + 1, 2
while str_pos <= #str and ptrn_pos <= #ptrn do
local ptrn_char = ptrn_lower:sub(ptrn_pos, ptrn_pos)
local str_char = str_lower:sub(str_pos, str_pos)
if ptrn_char == str_char then
-- If we've got multiple adjacent matches then give bonus points
if previous_match then score = score + adjacency_bonus end
previous_match = true
ptrn_pos = ptrn_pos + 1
else
-- If we don't match a letter then minus points
score = score + unmatched_letter_penalty
previous_match = false
end
str_pos = str_pos + 1
end
-- If we've matched the entire pattern then consider us as a candidate
if ptrn_pos > #ptrn and score > best_score then
best_score = score
best_start = start
end
start = start + 1
end
return best_score, best_start
end
return match_simple

16
lib.lua
View File

@ -126,7 +126,7 @@ local function serve(fn, node_type)
local end_time = os.clock() local end_time = os.clock()
print("Response:", textutils.serialise(result)) print("Response:", textutils.serialise(result))
print("Time:", string.format("%.1f", end_time - start_time)) print("Time:", string.format("%.1f", end_time - start_time))
response = { type = "response", response = result } response = { type = "OK", value = result }
end end
else else
print("Request Invalid") print("Request Invalid")
@ -252,7 +252,7 @@ local function init()
d.map(find_peripherals(function(type, name, wrapped) return type == "modem" end), function(p) rednet.open(p.name) end) d.map(find_peripherals(function(type, name, wrapped) return type == "modem" end), function(p) rednet.open(p.name) end)
end end
-- Rust-style unwrap. If x is a response type, will take out its contents and return them - if error, will crash and print it, with msg if provided -- Rust-style unwrap. If x is an OK table, will take out its contents and return them - if error, will crash and print it, with msg if provided
local function unwrap(x, msg) local function unwrap(x, msg)
if not x or type(x) ~= "table" or not x.type then x = errors.make(errors.INTERNAL, "Error/response object is invalid. This is probably a problem with the node being contacted.") end if not x or type(x) ~= "table" or not x.type then x = errors.make(errors.INTERNAL, "Error/response object is invalid. This is probably a problem with the node being contacted.") end
@ -262,9 +262,15 @@ local function unwrap(x, msg)
else text = text .. "!" end else text = text .. "!" end
text = text .. ".\nDetails: " .. errors.format(x) text = text .. ".\nDetails: " .. errors.format(x)
error(text) error(text)
elseif x.type == "response" then elseif x.type == "OK" then
return x.response return x.value
end end
end end
return { errors = errors, serve = serve, query_by_ID = query_by_ID, query_by_type = query_by_type, unwrap = unwrap, to_wyvern_item = to_wyvern_item, get_internal_identifier = get_internal_identifier, load_config = load_config, find_peripherals = find_peripherals, init = init, collate = collate, satisfied = satisfied, collate_stacks = collate_stacks } -- Wrap x in an OK result
local function make_OK(x)
return { type = "OK", value = x }
end
-- TODO: Not do this
return { errors = errors, serve = serve, query_by_ID = query_by_ID, query_by_type = query_by_type, unwrap = unwrap, to_wyvern_item = to_wyvern_item, get_internal_identifier = get_internal_identifier, load_config = load_config, find_peripherals = find_peripherals, init = init, collate = collate, satisfied = satisfied, collate_stacks = collate_stacks, make_error = errors.make, make_OK = make_OK }