Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing Operator types #84

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ using Catlab
export
DerivOp, append_dot, normalize_unicode, infer_states, infer_types!,
# Deca
op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D,
op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D,
op1_operators, op1_1D_bound_operators, op1_2D_bound_operators, op2_operators,
recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!,
## collages
Collage, collate,
Expand All @@ -18,8 +17,7 @@ apex, @relation, # Re-exported from Catlab
## acset
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode,
contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types,
resolve_overloads!, replace_names!,
apply_inference_rule_op1!, apply_inference_rule_op2!,
resolve_overloads!, replace_names!, type_check,
transfer_parents!, transfer_children!,
unique_lits!,
## language
Expand All @@ -32,7 +30,9 @@ to_graphviz, # Re-exported from Catlab
## rewrite
average_rewrite,
## openoperators
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!,
Operator, same_type_rules_op, arthimetic_operators, infer_resolve!, type_check, DecaTypeExeception


using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom
Expand Down
216 changes: 160 additions & 56 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using ACSets.InterTypes

@intertypes "decapodeacset.it" module decapodeacset end

import Base.show

using .decapodeacset

# Transferring pointers
Expand Down Expand Up @@ -362,7 +364,7 @@ function find_chains(d::SummationDecapode;

filter!(x -> passes_white_list(d[x, :op1]), chain_starts)
filter!(x -> passes_black_list(d[x, :op1]), chain_starts)

s = Stack{Int64}()
foreach(x -> push!(s, x), chain_starts)
while !isempty(s)
Expand Down Expand Up @@ -440,6 +442,50 @@ function filterfor_ec_types(types::AbstractVector{Symbol})
filter(conditions, types)
end

struct Operator{T}
res_type::T
src_types::AbstractVector{T}
op_name::Symbol
aliases::AbstractVector{Symbol}

function Operator{T}(res_type::T, src_types::AbstractVector{T}, op_name, aliases = Symbol[]) where T
new(res_type, src_types, op_name, aliases)
end

function Operator{T}(res_type::T, src_type::T, op_name, aliases = Symbol[]) where T
new(res_type, T[src_type], op_name, aliases)
end

function Operator(res_type::Symbol, src_type::Union{Symbol, AbstractVector{Symbol}}, op_name, aliases = Symbol[])
Operator{Symbol}(res_type, src_type, op_name, aliases)
end
end

function same_type_rules_op(op_name::Symbol, types::AbstractVector{Symbol}, arity::Int, g_aliases::AbstractVector{Symbol} = Symbol[], sp_aliases::AbstractVector = Symbol[])
@assert isempty(sp_aliases) || length(types) == length(sp_aliases)
map(1:length(types)) do i
aliases = isempty(sp_aliases) ? g_aliases : vcat(g_aliases, sp_aliases[i])
Operator{Symbol}(types[i], repeat([types[i]], arity), op_name, aliases)
end
end

function arthimetic_operators(op_name::Symbol, broadcasted::Bool, arity::Int = 2)
@match (broadcasted, arity) begin
(true, 2) => bin_broad_arith_ops(op_name)
_ => error("This type of arthimetic operator is not yet supported or may not be valid.")
GeorgeR227 marked this conversation as resolved.
Show resolved Hide resolved
end
end

function bin_broad_arith_ops(op_name)
all_ops = map(t -> Operator{Symbol}(t, [t, t], op_name), FORM_TYPES)
for type in vcat(USER_TYPES, NUMBER_TYPES)
append!(all_ops, map(t -> Operator{Symbol}(t, [t, type], op_name), FORM_TYPES))
append!(all_ops, map(t -> Operator{Symbol}(t, [type, t], op_name), FORM_TYPES))
end

all_ops
end

function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
# Note that we are not doing any type checking here for users!
# i.e. We are not checking the underlying types of Constant or Parameter
Expand All @@ -466,36 +512,99 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
return applied
end

function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
score_src = (rule.src_type == d[d[op1_id, :src], :type])
score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type])
function check_operator(d::SummationDecapode, op_id, rule, edge_val; check_name::Bool = false, check_aliases::Bool = false, ignore_infers::Bool = false, ignore_usertypes::Bool = false)
inputs = edge_inputs(d, op_id, edge_val)
output = edge_output(d, op_id, edge_val)

max_score = length(inputs) + length(output)

rule_types = vcat(rule.src_types, rule.res_type)
deca_types = vcat(d[inputs, :type], d[output, :type])

score = mapreduce(+, zip(rule_types, deca_types); init = 0) do (rule_t, deca_t)
if ignore_infers && deca_t in INFER_TYPES
return 1
elseif ignore_usertypes && deca_t in USER_TYPES
return 1
else
return rule_t == deca_t
end
end

dop_name = edge_function(d, op_id, edge_val)

named = check_name && dop_name == rule.op_name
aliased = check_aliases && dop_name in rule.aliases

(named || aliased, max_score - score)
end

function apply_inference_rule!(d::SummationDecapode, op_id, rule, edge_val)

name_present, type_diff = check_operator(d, op_id, rule, edge_val; check_name = true, check_aliases = true)

check_op = (d[op1_id, :op1] in rule.op_names)
if(check_op && (score_src + score_tgt == 1))
mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type)
mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type)
return mod_src || mod_tgt
if name_present && type_diff == 1
vars = vcat(edge_inputs(d, op_id, edge_val), edge_output(d, op_id, edge_val))
types = vcat(rule.src_types, rule.res_type)
return any(map(zip(vars, types)) do (var, type)
GeorgeR227 marked this conversation as resolved.
Show resolved Hide resolved
safe_modifytype!(d, var, type)
end)
end

return false
end

function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule)
score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type])
score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type])
score_res = (rule.res_type == d[d[op2_id, :res], :type])
function apply_overloading_rule!(d::SummationDecapode, op_id, rule, edge_val)

check_op = (d[op2_id, :op2] in rule.op_names)
if check_op && (score_proj1 + score_proj2 + score_res == 2)
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type)
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
return mod_proj1 || mod_proj2 || mod_res
name_present, type_diff = check_operator(d, op_id, rule, edge_val; check_aliases = true)

if name_present && type_diff == 0
set_edge_label!(d, op_id, rule.op_name, edge_val)
return true
end

return false
end

struct DecaTypeError{T}
rule::Operator{T}
idx::Int
table::Symbol
end

Base.show(io::IO, type_error::DecaTypeError{T}) where T = println("Operator at index $(type_error.idx) in table $(type_error.table) is not correctly typed. Perhaps the operator was meant to be $(type_error.rule)?")

struct DecaTypeExeception{T} <: Exception
type_errors::Vector{DecaTypeError{T}}
end

function Base.show(io::IO, type_except::DecaTypeExeception{T}) where T
map(x -> Base.show(io, x), type_except.type_errors)
end

function run_typechecking(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}})

type_errors = DecaTypeError{Symbol}[]

for table in [:Op1, :Op2]
for op_idx in parts(d, table)
type_error = run_typechecking_for_op(d, op_idx, type_rules, Val(table))
if type_error !== nothing
push!(type_errors, type_error)
end
end
end

return type_errors
end

function run_typechecking_for_op(d::SummationDecapode, op_id, type_rules, edge_val::Val{table}) where table
min_diff, min_rule_idx = findmin(type_rules) do rule
name_present, type_diff = check_operator(d, op_id, rule, edge_val; check_name = true, check_aliases = true, ignore_infers = true, ignore_usertypes = true)
name_present ? type_diff : Inf
end
min_diff in [0,Inf] ? nothing : DecaTypeError{Symbol}(type_rules[min_rule_idx], op_id, table)
end

# TODO: Although the big-O complexity is the same, it might be more efficent on
# average to iterate over edges then rules, instead of rules then edges. This
Expand All @@ -506,7 +615,7 @@ end

Infer types of Vars given rules wherein one type is known and the other not.
"""
function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :op_names), Tuple{Symbol, Symbol, Vector{Symbol}}}}, op2_rules::Vector{NamedTuple{(:proj1_type, :proj2_type, :res_type, :op_names), Tuple{Symbol, Symbol, Symbol, Vector{Symbol}}}})
function infer_types!(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}})

# This is an optimization so we do not "visit" a row which has no infer types.
# It could be deleted if found to be not worth maintainability tradeoff.
Expand All @@ -519,28 +628,23 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t
types_known_op2[incident(d, :infer, [:proj2, :type])] .= false
types_known_op2[incident(d, :infer, [:res, :type])] .= false

types_known = Dict{Symbol, Vector{Bool}}(:Op1 => types_known_op1, :Op2 => types_known_op2)

while true
applied = false

for rule in op1_rules
for op1_idx in parts(d, :Op1)
types_known_op1[op1_idx] && continue

this_applied = apply_inference_rule_op1!(d, op1_idx, rule)

types_known_op1[op1_idx] = this_applied
applied |= this_applied
end
end
for table in [:Op1, :Op2]
for op_idx in parts(d, table)
types_known[table][op_idx] && continue

for rule in op2_rules
for op2_idx in parts(d, :Op2)
types_known_op2[op2_idx] && continue
for rule in type_rules
this_applied = apply_inference_rule!(d, op_idx, rule, Val(table))

this_applied = apply_inference_rule_op2!(d, op2_idx, rule)
types_known[table][op_idx] = this_applied
applied |= this_applied
end

types_known_op2[op2_idx] = this_applied
applied |= this_applied
end
end

Expand All @@ -554,38 +658,38 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t
d
end



""" function resolve_overloads!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :resolved_name, :op), NTuple{4, Symbol}}})

Resolve function overloads based on types of src and tgt.
"""
function resolve_overloads!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :resolved_name, :op), NTuple{4, Symbol}}}, op2_rules::Vector{NamedTuple{(:proj1_type, :proj2_type, :res_type, :resolved_name, :op), NTuple{5, Symbol}}})
for op1_idx in parts(d, :Op1)
src = d[:src][op1_idx]; tgt = d[:tgt][op1_idx]; op1 = d[:op1][op1_idx]
src_type = d[:type][src]; tgt_type = d[:type][tgt]
for rule in op1_rules
if op1 == rule[:op] && src_type == rule[:src_type] && tgt_type == rule[:tgt_type]
d[op1_idx, :op1] = rule[:resolved_name]
break
end
end
end

for op2_idx in parts(d, :Op2)
proj1 = d[:proj1][op2_idx]; proj2 = d[:proj2][op2_idx]; res = d[:res][op2_idx]; op2 = d[:op2][op2_idx]
proj1_type = d[:type][proj1]; proj2_type = d[:type][proj2]; res_type = d[:type][res]
for rule in op2_rules
if op2 == rule[:op] && proj1_type == rule[:proj1_type] && proj2_type == rule[:proj2_type] && res_type == rule[:res_type]
d[op2_idx, :op2] = rule[:resolved_name]
break
function resolve_overloads!(d::SummationDecapode, resolve_rules::AbstractVector{Operator{Symbol}})
for rule in resolve_rules
for table in [:Op1, :Op2]
for op_idx in parts(d, table)
apply_overloading_rule!(d, op_idx, rule, Val(table))
end
end
end

d
end

function type_check(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}})
GeorgeR227 marked this conversation as resolved.
Show resolved Hide resolved
type_errors = run_typechecking(d, type_rules)

isempty(type_errors) && return true

throw(DecaTypeExeception{Symbol}(type_errors))
return false
end

function infer_resolve!(d::SummationDecapode, operators::AbstractVector{Operator{Symbol}})
infer_types!(d, operators)
resolve_overloads!(d, operators)
type_check(d, operators)

d
end

function replace_names!(d::SummationDecapode, op1_repls::Vector{Pair{Symbol, Any}}, op2_repls::Vector{Pair{Symbol, Symbol}})
for (orig,repl) in op1_repls
Expand Down
6 changes: 4 additions & 2 deletions src/deca/Deca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ using Catlab

using Reexport

import ..infer_types!, ..resolve_overloads!
import ..infer_types!, ..resolve_overloads!, ..type_check, ..infer_resolve!

export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec!
export normalize_unicode, varname, infer_types!, resolve_overloads!, type_check, infer_resolve!,
typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, vec_to_dec!,
op1_operators, op1_1D_bound_operators, op1_2D_bound_operators, op2_operators

include("deca_acset.jl")
include("deca_visualization.jl")
Expand Down
Loading
Loading