Skip to content

Commit

Permalink
Template Inference with basic unification. Remove ConcreteTemplateArg
Browse files Browse the repository at this point in the history
  • Loading branch information
VonTum committed Jan 30, 2025
1 parent 9ac82ae commit ca04e50
Show file tree
Hide file tree
Showing 12 changed files with 138 additions and 219 deletions.
18 changes: 9 additions & 9 deletions src/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl<IndexMarker> UUIDRange<IndexMarker> {
pub fn iter(&self) -> UUIDRangeIter<IndexMarker> {
self.into_iter()
}
pub fn map<OT, F: FnMut(UUID<IndexMarker>) -> OT>(&self, f: F) -> FlatAlloc<OT, IndexMarker> {
pub fn map<OT>(&self, f: impl FnMut(UUID<IndexMarker>) -> OT) -> FlatAlloc<OT, IndexMarker> {
FlatAlloc {
data: Vec::from_iter(self.iter().map(f)),
_ph: PhantomData,
Expand Down Expand Up @@ -302,9 +302,9 @@ impl<T, IndexMarker> ArenaAllocator<T, IndexMarker> {
pub fn iter_mut<'a>(&'a mut self) -> FlatOptionIteratorMut<'a, T, IndexMarker> {
self.into_iter()
}
pub fn find<F: FnMut(UUID<IndexMarker>, &T) -> bool>(
pub fn find(
&self,
mut predicate: F,
mut predicate: impl FnMut(UUID<IndexMarker>, &T) -> bool,
) -> Option<UUID<IndexMarker>> {
self.iter()
.find(|(id, v)| predicate(*id, v))
Expand Down Expand Up @@ -428,9 +428,9 @@ impl<T, IndexMarker> ArenaVector<T, IndexMarker> {
pub fn iter_mut<'a>(&'a mut self) -> FlatOptionIteratorMut<'a, T, IndexMarker> {
self.into_iter()
}
pub fn find<F: FnMut(UUID<IndexMarker>, &T) -> bool>(
pub fn find(
&self,
mut predicate: F,
mut predicate: impl FnMut(UUID<IndexMarker>, &T) -> bool,
) -> Option<UUID<IndexMarker>> {
self.iter()
.find(|(id, v)| predicate(*id, v))
Expand Down Expand Up @@ -537,18 +537,18 @@ impl<T, IndexMarker> FlatAlloc<T, IndexMarker> {
pub fn iter_mut<'a>(&'a mut self) -> FlatAllocIterMut<'a, T, IndexMarker> {
self.into_iter()
}
pub fn map<OT, F: FnMut((UUID<IndexMarker>, &T)) -> OT>(
pub fn map<OT>(
&self,
f: F,
f: impl FnMut((UUID<IndexMarker>, &T)) -> OT,
) -> FlatAlloc<OT, IndexMarker> {
FlatAlloc {
data: Vec::from_iter(self.iter().map(f)),
_ph: PhantomData,
}
}
pub fn find<F: FnMut(UUID<IndexMarker>, &T) -> bool>(
pub fn find(
&self,
mut predicate: F,
mut predicate: impl FnMut(UUID<IndexMarker>, &T) -> bool,
) -> Option<UUID<IndexMarker>> {
self.iter()
.find(|(id, v)| predicate(*id, v))
Expand Down
10 changes: 5 additions & 5 deletions src/codegen/system_verilog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::flattening::{DeclarationKind, Instruction, Module, Port};
use crate::instantiation::{
InstantiatedModule, RealWire, RealWireDataSource, RealWirePathElem, CALCULATE_LATENCY_LATER,
};
use crate::typing::template::{ConcreteTemplateArg, TVec};
use crate::typing::template::TVec;
use crate::{typing::concrete_type::ConcreteType, value::Value};

use super::shared::*;
Expand Down Expand Up @@ -333,21 +333,21 @@ impl<'g> CodeGenerationContext<'g> {
fn write_template_args(
&mut self,
link_info: &LinkInfo,
concrete_template_args: &TVec<ConcreteTemplateArg>,
concrete_template_args: &TVec<ConcreteType>,
) {
self.program_text.write_str(&link_info.name).unwrap();
self.program_text.write_str(" #(").unwrap();
let mut first = true;
concrete_template_args.iter().for_each(|(arg_id, arg)| {
let arg_name = &link_info.template_parameters[arg_id].name;
let arg_value = match arg {
ConcreteTemplateArg::Type(..) => {
ConcreteType::Named(..) | ConcreteType::Array(..) => {
unreachable!("No extern module type arguments. Should have been caught by Lint")
}
ConcreteTemplateArg::Value(value, _) => {
ConcreteType::Value(value) => {
value.inline_constant_to_string()
}
ConcreteTemplateArg::NotProvided => unreachable!("All args are known at codegen"),
ConcreteType::Unknown(_) => unreachable!("All args are known at codegen"),
};
if first {
self.program_text.write_char(',').unwrap();
Expand Down
18 changes: 9 additions & 9 deletions src/flattening/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,13 @@ impl<'t> Cursor<'t> {
}

#[track_caller]
pub fn go_down<OT, F: FnOnce(&mut Self) -> OT>(&mut self, kind: u16, func: F) -> OT {
pub fn go_down<OT>(&mut self, kind: u16, func: impl FnOnce(&mut Self) -> OT) -> OT {
self.assert_is_kind(kind);

self.go_down_no_check(func)
}

pub fn go_down_no_check<OT, F: FnOnce(&mut Self) -> OT>(&mut self, func: F) -> OT {
pub fn go_down_no_check<OT>(&mut self, func: impl FnOnce(&mut Self) -> OT) -> OT {
if !self.cursor.goto_first_child() {
self.print_stack();
panic!("Could not go down this node!");
Expand All @@ -219,7 +219,7 @@ impl<'t> Cursor<'t> {

/// Goes down the current node, checks it's kind, and then iterates through 'item' fields.
#[track_caller]
pub fn list<F: FnMut(&mut Self)>(&mut self, parent_kind: u16, mut func: F) {
pub fn list(&mut self, parent_kind: u16, mut func: impl FnMut(&mut Self)) {
self.assert_is_kind(parent_kind);

if self.cursor.goto_first_child() {
Expand Down Expand Up @@ -251,10 +251,10 @@ impl<'t> Cursor<'t> {
///
/// The function given should return OT, and from the valid outputs this function constructs a output list
#[track_caller]
pub fn collect_list<OT, F: FnMut(&mut Self) -> OT>(
pub fn collect_list<OT>(
&mut self,
parent_kind: u16,
mut func: F,
mut func: impl FnMut(&mut Self) -> OT,
) -> Vec<OT> {
let mut result = Vec::new();

Expand All @@ -268,10 +268,10 @@ impl<'t> Cursor<'t> {

/// Goes down the current node, checks it's kind, and then selects the 'content' field. Useful for constructs like seq('[', field('content', $.expr), ']')
#[track_caller]
pub fn go_down_content<OT, F: FnOnce(&mut Self) -> OT>(
pub fn go_down_content<OT>(
&mut self,
parent_kind: u16,
func: F,
func: impl FnOnce(&mut Self) -> OT,
) -> OT {
self.go_down(parent_kind, |self2| {
self2.field(field!("content"));
Expand Down Expand Up @@ -361,11 +361,11 @@ impl<'t> Cursor<'t> {

/// Goes down the current node, checks it's kind, and then iterates through 'item' fields.
#[track_caller]
pub fn list_and_report_errors<F: FnMut(&mut Self)>(
pub fn list_and_report_errors(
&mut self,
parent_kind: u16,
errors: &ErrorCollector,
mut func: F,
mut func: impl FnMut(&mut Self),
) {
self.assert_is_kind(parent_kind);
if self.cursor.goto_first_child() {
Expand Down
124 changes: 69 additions & 55 deletions src/instantiation/concrete_typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
use std::ops::Deref;

use crate::errors::ErrorInfoObject;
use crate::flattening::{DeclarationKind, WireReferenceRoot, ExpressionSource, WrittenType};
use crate::flattening::{DeclarationKind, ExpressionSource, WireReferenceRoot, WrittenType};
use crate::linker::LinkInfo;
use crate::typing::template::{ConcreteTemplateArg, HowDoWeKnowTheTemplateArg};
use crate::typing::concrete_type::ConcreteGlobalReference;
use crate::typing::template::TemplateArgKind;
use crate::typing::{
concrete_type::{ConcreteType, BOOL_CONCRETE_TYPE, INT_CONCRETE_TYPE},
type_inference::{FailedUnification, DelayedConstraint, DelayedConstraintStatus, DelayedConstraintsList},
Expand Down Expand Up @@ -137,7 +138,20 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> {

pub fn typecheck(&mut self) {
let mut delayed_constraints : DelayedConstraintsList<Self> = DelayedConstraintsList::new();
for (sm_id, _sm) in &self.submodules {
for (sm_id, sm) in &self.submodules {
let sub_module = &self.linker.modules[sm.module_uuid];

for (port_id, p) in sm.port_map.iter_valids() {
let wire = &self.wires[p.maps_to_wire];

let port_decl_instr = sub_module.ports[port_id].declaration_instruction;
let port_decl = sub_module.link_info.instructions[port_decl_instr].unwrap_declaration();

let typ_for_inference = concretize_written_type_with_possible_template_args(&port_decl.typ_expr, &sm.template_args, &sub_module.link_info, &self.type_substitutor);

self.type_substitutor.unify_must_succeed(&wire.typ, &typ_for_inference);
}

delayed_constraints.push(SubmoduleTypecheckConstraint {sm_id});
}

Expand Down Expand Up @@ -173,76 +187,76 @@ fn can_expression_be_value_inferred(link_info: &LinkInfo, expr_id: FlatID) -> Op
Some(*template_id)
}

fn try_to_attach_value_to_template_arg(template_wire_referernce: FlatID, found_value: &ConcreteType, template_args: &mut TVec<ConcreteTemplateArg>, submodule_link_info: &LinkInfo) {
let ConcreteType::Value(v) = found_value else {return}; // We don't have a value to assign
if let Some(template_id) = can_expression_be_value_inferred(submodule_link_info, template_wire_referernce) {
if let ConcreteTemplateArg::NotProvided = &template_args[template_id] {
template_args[template_id] = ConcreteTemplateArg::Value(v.clone(), HowDoWeKnowTheTemplateArg::Inferred)
}
}
}

fn infer_parameters_by_walking_type(port_wr_typ: &WrittenType, connected_typ: &ConcreteType, template_args: &mut TVec<ConcreteTemplateArg>, submodule_link_info: &LinkInfo) {
match port_wr_typ {
WrittenType::Error(_) => {} // Can't continue, bad written type
WrittenType::Named(_) => {} // Seems we've run out of type to check
WrittenType::Array(_span, written_arr_box) => {
let ConcreteType::Array(concrete_arr_box) = connected_typ else {return}; // Can't continue, type not worked out. TODO should we seed concrete types with derivates from AbstractTypes?
let (written_arr, written_size_var, _) = written_arr_box.deref();
let (concrete_arr, concrete_size) = concrete_arr_box.deref();

infer_parameters_by_walking_type(written_arr, concrete_arr, template_args, submodule_link_info); // Recurse down
fn concretize_written_type_with_possible_template_args(
written_typ: &WrittenType,
template_args: &TVec<ConcreteType>,
link_info: &LinkInfo,
type_substitutor: &TypeSubstitutor<ConcreteType, ConcreteTypeVariableIDMarker>
) -> ConcreteType {
match written_typ {
WrittenType::Error(_span) => ConcreteType::Unknown(type_substitutor.alloc()),
WrittenType::TemplateVariable(_span, uuid) => template_args[*uuid].clone(),
WrittenType::Named(global_reference) => {
let object_template_args : TVec<ConcreteType> = global_reference.template_args.map(|(_arg_id, arg)| -> ConcreteType {
if let Some(arg) = arg {
match &arg.kind {
TemplateArgKind::Type(arg_wr_typ) => {
concretize_written_type_with_possible_template_args(arg_wr_typ, template_args, link_info, type_substitutor)
}
TemplateArgKind::Value(uuid) => {
if let Some(found_template_arg) = can_expression_be_value_inferred(link_info, *uuid) {
template_args[found_template_arg].clone()
} else {
ConcreteType::Unknown(type_substitutor.alloc())
}
}
}
} else {
ConcreteType::Unknown(type_substitutor.alloc())
}
});

try_to_attach_value_to_template_arg(*written_size_var, concrete_size, template_args, submodule_link_info); // Potential place for template inference!
ConcreteType::Named(ConcreteGlobalReference{
id: global_reference.id,
template_args: object_template_args
})
}
WrittenType::TemplateVariable(_span, template_id) => {
if !connected_typ.contains_unknown() {
if let ConcreteTemplateArg::NotProvided = &template_args[*template_id] {
template_args[*template_id] = ConcreteTemplateArg::Type(connected_typ.clone(), HowDoWeKnowTheTemplateArg::Inferred)
}
}
WrittenType::Array(_span, arr_box) => {
let (arr_content_wr, arr_idx_id, _arr_brackets) = arr_box.deref();

let arr_content_concrete = concretize_written_type_with_possible_template_args(arr_content_wr, template_args, link_info, type_substitutor);
let arr_idx_concrete = if let Some(found_template_arg) = can_expression_be_value_inferred(link_info, *arr_idx_id) {
template_args[found_template_arg].clone()
} else {
ConcreteType::Unknown(type_substitutor.alloc())
};

ConcreteType::Array(Box::new((arr_content_concrete, arr_idx_concrete)))
}
}
}

impl SubmoduleTypecheckConstraint {
fn try_infer_parameters(&mut self, context: &mut InstantiationContext) {
let sm = &mut context.submodules[self.sm_id];

let sub_module = &context.linker.modules[sm.module_uuid];

for (id, p) in sm.port_map.iter_valids() {
let wire = &context.wires[p.maps_to_wire];

let mut wire_typ_clone = wire.typ.clone();
wire_typ_clone.fully_substitute(&context.type_substitutor);

let port_decl_instr = sub_module.ports[id].declaration_instruction;
let port_decl = sub_module.link_info.instructions[port_decl_instr].unwrap_declaration();

infer_parameters_by_walking_type(&port_decl.typ_expr, &wire_typ_clone, &mut sm.template_args, &sub_module.link_info);
}
/// Directly named type and value parameters are immediately unified, but latency count deltas can only be computed from the latency counting graph
fn try_infer_latency_counts(&mut self, context: &mut InstantiationContext) {
// TODO
}

}

impl DelayedConstraint<InstantiationContext<'_, '_>> for SubmoduleTypecheckConstraint {
fn try_apply(&mut self, context : &mut InstantiationContext) -> DelayedConstraintStatus {
// Try to infer template arguments based on the connections to the ports of the module
self.try_infer_parameters(context);
// Try to infer template arguments based on the connections to the ports of the module.
self.try_infer_latency_counts(context);

let sm = &context.submodules[self.sm_id];
let sm = &mut context.submodules[self.sm_id];

let submod_instr = context.md.link_info.instructions[sm.original_instruction].unwrap_submodule();
let sub_module = &context.linker.modules[sm.module_uuid];

// Check if there's any argument that isn't known
for (_id, arg) in &sm.template_args {
match arg {
ConcreteTemplateArg::NotProvided => {
return DelayedConstraintStatus::NoProgress;
}
ConcreteTemplateArg::Type(..) | ConcreteTemplateArg::Value(..) => {}
for (_id, arg) in &mut sm.template_args {
if !arg.fully_substitute(&context.type_substitutor) { // We don't actually *need* to already fully_substitute here, but it's convenient and saves some work
return DelayedConstraintStatus::NoProgress;
}
}

Expand Down
19 changes: 7 additions & 12 deletions src/instantiation/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::ops::{Deref, Index, IndexMut};

use crate::linker::IsExtern;
use crate::prelude::*;
use crate::typing::template::{GlobalReference, HowDoWeKnowTheTemplateArg};
use crate::typing::template::GlobalReference;

use num::BigInt;

Expand All @@ -19,7 +19,7 @@ use crate::value::{compute_binary_op, compute_unary_op, Value};
use crate::typing::{
abstract_type::DomainType,
concrete_type::{ConcreteType, INT_CONCRETE_TYPE},
template::{ConcreteTemplateArg, TemplateArgKind},
template::TemplateArgKind,
};

use super::*;
Expand Down Expand Up @@ -173,7 +173,7 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> {
Ok(match typ {
WrittenType::Error(_) => caught_by_typecheck!("Error Type"),
WrittenType::TemplateVariable(_, template_id) => {
self.template_args[*template_id].unwrap_type().clone()
self.template_args[*template_id].clone()
}
WrittenType::Named(named_type) => {
ConcreteType::Named(crate::typing::concrete_type::ConcreteGlobalReference {
Expand Down Expand Up @@ -712,16 +712,11 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> {
for (_id, v) in &submodule.module_ref.template_args {
template_args.alloc(match v {
Some(arg) => match &arg.kind {
TemplateArgKind::Type(typ) => ConcreteTemplateArg::Type(
self.concretize_type(typ)?,
HowDoWeKnowTheTemplateArg::Given,
),
TemplateArgKind::Value(v) => ConcreteTemplateArg::Value(
self.generation_state.get_generation_value(*v)?.clone(),
HowDoWeKnowTheTemplateArg::Given,
),
TemplateArgKind::Type(typ) => self.concretize_type(typ)?,
TemplateArgKind::Value(v) =>
ConcreteType::Value(self.generation_state.get_generation_value(*v)?.clone()),
},
None => ConcreteTemplateArg::NotProvided,
None => ConcreteType::Unknown(self.type_substitutor.alloc()),
});
}
SubModuleOrWire::SubModule(self.submodules.alloc(SubModule {
Expand Down
2 changes: 1 addition & 1 deletion src/instantiation/latency_count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct LatencyDomainInfo {
}

impl RealWireDataSource {
fn iter_sources_with_min_latency<F: FnMut(WireID, i64)>(&self, mut f: F) {
fn iter_sources_with_min_latency(&self, mut f: impl FnMut(WireID, i64)) {
match self {
RealWireDataSource::ReadOnly => {}
RealWireDataSource::Multiplexer {
Expand Down
Loading

0 comments on commit ca04e50

Please sign in to comment.