diff --git a/README.md b/README.md index e86a42b..621c00a 100644 --- a/README.md +++ b/README.md @@ -47,3 +47,11 @@ The mcc compiler follows a standard compiler architecture, consisting of the fol - **IR Generation**: Converts the AST into a [three-address code intermediate representation](https://en.wikipedia.org/wiki/Three-address_code). - **Assembly Generation**: Translates the intermediate representation into assembly code. + +## Resources + +Here are some useful resources consulted by me when writing this compiler + +- [Writing a C Compiler](https://norasandler.com/book/) book +- Other small C compiler + implementations: [9cc](https://github.com/rui314/9cc) [chibicc](https://github.com/rui314/chibicc), [lacc](https://github.com/larmel/lacc), [cproc](https://github.com/michaelforney/cproc) \ No newline at end of file diff --git a/include/mcc/ast.h b/include/mcc/ast.h index 8c742ce..d48a6e7 100644 --- a/include/mcc/ast.h +++ b/include/mcc/ast.h @@ -3,6 +3,7 @@ #include "source_location.h" #include "str.h" +#include "type.h" #include @@ -62,6 +63,7 @@ typedef enum BinaryOpType { } BinaryOpType; typedef struct Expr Expr; +typedef struct Variable Variable; // identifier struct ConstExpr { int32_t val; @@ -93,19 +95,20 @@ struct CallExpr { typedef struct Expr { SourceRange source_range; ExprTag tag; + const Type* type; // C type (e.g. 
void, int, int*) union { struct ConstExpr const_expr; struct UnaryOpExpr unary_op; struct BinaryOpExpr binary_op; - struct StringView variable; + const Variable* variable; struct TernaryExpr ternary; struct CallExpr call; }; } Expr; typedef struct VariableDecl { - StringView name; - const Expr* initializer; // An optional initializer + Variable* name; + Expr* initializer; // An optional initializer } VariableDecl; typedef enum StmtTag { @@ -149,24 +152,24 @@ struct Stmt { struct ReturnStmt { Expr* expr; } ret; - const Expr* expr; + Expr* expr; struct IfStmt { - const Expr* cond; - const Stmt* then; - const Stmt* els; // optional, can be nullptr + Expr* cond; + Stmt* then; + Stmt* els; // optional, can be nullptr } if_then; // while or do while loop struct While { - const Expr* cond; - const Stmt* body; + Expr* cond; + Stmt* body; } while_loop; struct For { ForInit init; - const Expr* cond; // optional, can be nullptr - const Expr* post; // optional, can be nullptr - const Stmt* body; + Expr* cond; // optional, can be nullptr + Expr* post; // optional, can be nullptr + Stmt* body; } for_loop; }; }; @@ -184,13 +187,9 @@ typedef struct BlockItem { }; } BlockItem; -typedef struct Parameter { - StringView name; -} Parameter; - typedef struct Parameters { uint32_t length; - Parameter* data; + Variable** data; } Parameters; typedef struct FunctionDecl { diff --git a/include/mcc/cli_args.h b/include/mcc/cli_args.h index 2bb2c10..102f53c 100644 --- a/include/mcc/cli_args.h +++ b/include/mcc/cli_args.h @@ -6,6 +6,7 @@ typedef struct CliArgs { const char* source_filename; // Filename of the source file (with extension) bool stop_after_lexer; bool stop_after_parser; + bool stop_after_semantic_analysis; bool gen_ir_only; // Stop after generating the IR bool codegen_only; // generate assembly; but does not save to a file diff --git a/include/mcc/ir.h b/include/mcc/ir.h index dd59e96..0d85853 100644 --- a/include/mcc/ir.h +++ b/include/mcc/ir.h @@ -15,7 +15,7 @@ typedef struct 
IRGenerationResult { } IRGenerationResult; struct TranslationUnit; -IRGenerationResult ir_generate(struct TranslationUnit* ast, +IRGenerationResult ir_generate(const struct TranslationUnit* ast, Arena* permanent_arena, Arena scratch_arena); void print_ir(const struct IRProgram* ir); diff --git a/include/mcc/sema.h b/include/mcc/sema.h new file mode 100644 index 0000000..ff7b16b --- /dev/null +++ b/include/mcc/sema.h @@ -0,0 +1,13 @@ +#ifndef MCC_SEMA_H +#define MCC_SEMA_H + +// Semantic analysis pass + +#include "arena.h" +#include "diagnostic.h" + +typedef struct TranslationUnit TranslationUnit; + +ErrorsView type_check(TranslationUnit* ast, Arena* permanent_arena); + +#endif // MCC_SEMA_H diff --git a/include/mcc/token.h b/include/mcc/token.h index 37dd78a..78bc5ad 100644 --- a/include/mcc/token.h +++ b/include/mcc/token.h @@ -4,7 +4,7 @@ #include "source_location.h" #include "str.h" -typedef enum TokenType : char { +typedef enum TokenTag : char { TOKEN_INVALID = 0, TOKEN_LEFT_PAREN, // ( @@ -75,17 +75,17 @@ typedef enum TokenType : char { TOKEN_EOF, TOKEN_TYPES_COUNT, -} TokenType; +} TokenTag; typedef struct Token { - TokenType type; + TokenTag tag; uint32_t start; // The offset of the starting character in a token uint32_t size; } Token; /// @brief An SOA view of tokens typedef struct Tokens { - TokenType* token_types; + TokenTag* token_types; uint32_t* token_starts; uint32_t* token_sizes; uint32_t token_count; @@ -94,7 +94,7 @@ typedef struct Tokens { inline static Token get_token(const Tokens* tokens, uint32_t i) { MCC_ASSERT(i < tokens->token_count); - return MCC_COMPOUND_LITERAL(Token){.type = tokens->token_types[i], + return MCC_COMPOUND_LITERAL(Token){.tag = tokens->token_types[i], .start = tokens->token_starts[i], .size = tokens->token_sizes[i]}; } diff --git a/include/mcc/type.h b/include/mcc/type.h new file mode 100644 index 0000000..3e9e541 --- /dev/null +++ b/include/mcc/type.h @@ -0,0 +1,44 @@ +#ifndef MCC_TYPE_H +#define MCC_TYPE_H + +#include 
+#include + +#include "arena.h" +#include "str.h" + +typedef enum TypeTag : uint8_t { + TYPE_INVALID = 0, + TYPE_VOID, + TYPE_INTEGER, + TYPE_FUNCTION, +} TypeTag; + +typedef struct Type { + uint32_t size; + uint32_t alignment; + alignas(max_align_t) TypeTag tag; +} Type; + +typedef struct IntegerType { + Type base; + bool is_unsigned; + const char* name; +} IntegerType; + +typedef struct FunctionType { + Type base; + const Type* return_type; + uint32_t param_count; +} FunctionType; + +extern const Type* typ_invalid; +extern const Type* typ_void; +extern const Type* typ_int; + +const Type* func_type(const Type* return_type, uint32_t param_count, + Arena* arena); + +void format_type_to(StringBuffer* buffer, const Type* typ); + +#endif // MCC_TYPE_H diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 06c2e44..3984ee2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -13,6 +13,8 @@ add_library(mcc_lib ${include_dir}/ir.h ${include_dir}/x86.h ${include_dir}/dynarray.h + ${include_dir}/type.h + ${include_dir}/sema.h utils/format.c utils/str.c @@ -26,6 +28,10 @@ add_library(mcc_lib frontend/ast_printer.c frontend/parser.c frontend/token.c + frontend/type.c + frontend/type_check.c + frontend/symbol_table.h + frontend/symbol_table.c ir/ir_generator.c ir/ir_printer.c diff --git a/src/frontend/ast_printer.c b/src/frontend/ast_printer.c index 2923485..973586d 100644 --- a/src/frontend/ast_printer.c +++ b/src/frontend/ast_printer.c @@ -1,6 +1,8 @@ #include #include +#include "symbol_table.h" + static void print_str(StringView str) { printf("%.*s", (int)str.size, str.start); @@ -86,7 +88,7 @@ static void ast_print_expr(const Expr* expr, int indent) printf("%*sVariableExpr ", indent, ""); print_source_range(expr->source_range); printf(" "); - print_str(expr->variable); + print_str(expr->variable->name); printf("\n"); break; case EXPR_TERNARY: @@ -134,7 +136,7 @@ static void ast_print_decl(const VariableDecl* decl, int indent) { printf("%*sVariableDecl ", indent, 
""); printf("int "); - print_str(decl->name); + print_str(decl->name->name); printf("\n"); if (decl->initializer) { ast_print_expr(decl->initializer, indent + 2); } } @@ -219,10 +221,10 @@ static void ast_print_parameters(Parameters parameters) printf("("); for (uint32_t i = 0; i < parameters.length; ++i) { if (i > 0) { printf(", "); } - const Parameter param = parameters.data[i]; + const Variable* param = parameters.data[i]; printf("int"); - if (param.name.size != 0) { - printf(" %.*s", (int)param.name.size, param.name.start); + if (param->name.size != 0) { + printf(" %.*s", (int)param->name.size, param->name.start); } } printf(")\"\n"); diff --git a/src/frontend/lexer.c b/src/frontend/lexer.c index 9e6f0a3..3261b9a 100644 --- a/src/frontend/lexer.c +++ b/src/frontend/lexer.c @@ -31,10 +31,10 @@ static bool lexer_is_at_end(const Lexer* lexer) return *lexer->current == '\0'; } -static Token make_token(const Lexer* lexer, TokenType type) +static Token make_token(const Lexer* lexer, TokenTag type) { return (Token){ - .type = type, + .tag = type, .start = u32_from_isize(lexer->previous - lexer->start), .size = u32_from_isize(lexer->current - lexer->previous), }; @@ -146,8 +146,8 @@ static bool eat_char_if_match(Lexer* lexer, const char expected) return true; } -static TokenType check_keyword(Lexer* lexer, int start_position, - StringView rest, TokenType type) +static TokenTag check_keyword(Lexer* lexer, int start_position, StringView rest, + TokenTag type) { if (lexer->current - lexer->previous == start_position + (int)rest.size && memcmp(lexer->previous + start_position, rest.start, rest.size) == 0) { @@ -157,7 +157,7 @@ static TokenType check_keyword(Lexer* lexer, int start_position, return TOKEN_IDENTIFIER; } -static TokenType get_identifier_type(Lexer* lexer) +static TokenTag get_identifier_type(Lexer* lexer) { switch (lexer->previous[0]) { case 'b': return check_keyword(lexer, 1, str("reak"), TOKEN_KEYWORD_BREAK); @@ -202,14 +202,14 @@ static Token 
scan_symbol(Lexer* lexer) case '}': return make_token(lexer, TOKEN_RIGHT_BRACE); case ';': return make_token(lexer, TOKEN_SEMICOLON); case '+': { - const TokenType token_type = eat_char_if_match(lexer, '+') ? TOKEN_PLUS_PLUS - : eat_char_if_match(lexer, '=') - ? TOKEN_PLUS_EQUAL - : TOKEN_PLUS; + const TokenTag token_type = eat_char_if_match(lexer, '+') ? TOKEN_PLUS_PLUS + : eat_char_if_match(lexer, '=') + ? TOKEN_PLUS_EQUAL + : TOKEN_PLUS; return make_token(lexer, token_type); } case '-': { - const TokenType token_type = + const TokenTag token_type = eat_char_if_match(lexer, '-') ? TOKEN_MINUS_MINUS : eat_char_if_match(lexer, '=') ? TOKEN_MINUS_EQUAL : eat_char_if_match(lexer, '>') ? TOKEN_MINUS_GREATER @@ -227,17 +227,17 @@ static Token scan_symbol(Lexer* lexer) : TOKEN_PERCENT); case '~': return make_token(lexer, TOKEN_TILDE); case '&': { - const TokenType token_type = + const TokenTag token_type = eat_char_if_match(lexer, '&') ? TOKEN_AMPERSAND_AMPERSAND : eat_char_if_match(lexer, '=') ? TOKEN_AMPERSAND_EQUAL : TOKEN_AMPERSAND; return make_token(lexer, token_type); } case '|': { - const TokenType token_type = eat_char_if_match(lexer, '|') ? TOKEN_BAR_BAR - : eat_char_if_match(lexer, '=') - ? TOKEN_BAR_EQUAL - : TOKEN_BAR; + const TokenTag token_type = eat_char_if_match(lexer, '|') ? TOKEN_BAR_BAR + : eat_char_if_match(lexer, '=') + ? TOKEN_BAR_EQUAL + : TOKEN_BAR; return make_token(lexer, token_type); } case '^': @@ -250,7 +250,7 @@ static Token scan_symbol(Lexer* lexer) return make_token(lexer, eat_char_if_match(lexer, '=') ? TOKEN_NOT_EQUAL : TOKEN_NOT); case '<': { - const TokenType token_type = + const TokenTag token_type = eat_char_if_match(lexer, '<') ? (eat_char_if_match(lexer, '=') ? TOKEN_LESS_LESS_EQUAL : TOKEN_LESS_LESS) @@ -259,7 +259,7 @@ static Token scan_symbol(Lexer* lexer) return make_token(lexer, token_type); } case '>': { - const TokenType token_type = + const TokenTag token_type = eat_char_if_match(lexer, '>') ? 
(eat_char_if_match(lexer, '=') ? TOKEN_GREATER_GREATER_EQUAL : TOKEN_GREATER_GREATER) @@ -294,7 +294,7 @@ static Token scan_token(Lexer* lexer) struct TokenTypeDynArray { size_t length; size_t capacity; - TokenType* data; + TokenTag* data; }; struct U32DynArray { @@ -313,23 +313,23 @@ Tokens lex(const char* source, Arena* permanent_arena, Arena scratch_arena) while (true) { const Token token = scan_token(&lexer); - DYNARRAY_PUSH_BACK(&token_types_dyn_array, TokenType, &scratch_arena, - token.type); + DYNARRAY_PUSH_BACK(&token_types_dyn_array, TokenTag, &scratch_arena, + token.tag); DYNARRAY_PUSH_BACK(&token_starts_dyn_array, uint32_t, &scratch_arena, token.start); DYNARRAY_PUSH_BACK(&token_sizes_dyn_array, uint32_t, &scratch_arena, token.size); - if (token.type == TOKEN_EOF) { break; } + if (token.tag == TOKEN_EOF) { break; } } const uint32_t token_count = u32_from_usize(token_types_dyn_array.length); MCC_ASSERT(token_starts_dyn_array.length == token_count); MCC_ASSERT(token_sizes_dyn_array.length == token_count); - TokenType* token_types = - ARENA_ALLOC_ARRAY(permanent_arena, TokenType, token_count); + TokenTag* token_types = + ARENA_ALLOC_ARRAY(permanent_arena, TokenTag, token_count); memcpy(token_types, token_types_dyn_array.data, - token_count * sizeof(TokenType)); + token_count * sizeof(TokenTag)); uint32_t* token_starts = ARENA_ALLOC_ARRAY(permanent_arena, uint32_t, token_count); diff --git a/src/frontend/parser.c b/src/frontend/parser.c index 8ecac93..9998e50 100644 --- a/src/frontend/parser.c +++ b/src/frontend/parser.c @@ -8,7 +8,9 @@ #include #include -struct ParseErrorVec { +#include "symbol_table.h" + +struct ErrorVec { size_t length; size_t capacity; Error* data; @@ -26,7 +28,7 @@ typedef struct Parser { bool has_error; bool in_panic_mode; - struct ParseErrorVec errors; + struct ErrorVec errors; struct Scope* global_scope; } Parser; @@ -35,7 +37,7 @@ typedef struct Parser { static SourceRange token_source_range(Token token) { const uint32_t begin = - 
(token.type == TOKEN_EOF) ? token.start - 1 : token.start; + (token.tag == TOKEN_EOF) ? token.start - 1 : token.start; return (SourceRange){.begin = begin, .end = token.start + token.size}; } @@ -101,33 +103,33 @@ static Token parser_previous_token(Parser* parser) return get_token(&parser->tokens, previous_token_index); } -static bool token_match_or_eof(const Parser* parser, TokenType typ) +static bool token_match_or_eof(const Parser* parser, TokenTag typ) { const Token current_token = parser_current_token(parser); - return current_token.type == typ || current_token.type == TOKEN_EOF; + return current_token.tag == typ || current_token.tag == TOKEN_EOF; } // Advance tokens by one // Also skip any error tokens static void parse_advance(Parser* parser) { - if (parser_current_token(parser).type == TOKEN_EOF) { return; } + if (parser_current_token(parser).tag == TOKEN_EOF) { return; } for (;;) { parser->current_token_index++; Token current = parser_current_token(parser); - if (current.type != TOKEN_ERROR) break; + if (current.tag != TOKEN_ERROR) break; parse_panic_at_token(parser, str("unexpected character"), current); } } // Consume the current token. If the token doesn't have specified type, generate // an error. -static void parse_consume(Parser* parser, TokenType type, const char* error_msg) +static void parse_consume(Parser* parser, TokenTag type, const char* error_msg) { const Token current = parser_current_token(parser); - if (current.type == type) { + if (current.tag == type) { parse_advance(parser); return; } @@ -136,77 +138,6 @@ static void parse_consume(Parser* parser, TokenType type, const char* error_msg) } #pragma endregion -#pragma region Scope and name resolution -typedef struct Name { - StringView - name; // name in the source. 
This is the name used for variable lookup - StringView rewrote_name; // name after alpha renaming - uint32_t shadow_counter; // increase each time we have shadowing -} Name; - -// Represents a block scope -// TODO: use hash table -typedef struct Scope { - uint32_t length; - uint32_t capacity; - Name* data; - struct Scope* parent; -} Scope; - -static struct Scope* new_scope(Scope* parent, Arena* arena) -{ - struct Scope* map = ARENA_ALLOC_OBJECT(arena, Scope); - *map = (struct Scope){ - .parent = parent, - }; - return map; -} - -static Name* lookup_name(const Scope* scope, StringView name) -{ - for (uint32_t i = 0; i < scope->length; ++i) { - if (str_eq(name, scope->data[i].name)) { return &scope->data[i]; } - } - - if (scope->parent == nullptr) return nullptr; - - return lookup_name(scope->parent, name); -} - -// Return nullptr if a variable already exist in the same scope -static Name* add_name(StringView name, Scope* scope, Arena* arena) -{ - // check variable in current scope - for (uint32_t i = 0; i < scope->length; ++i) { - if (str_eq(name, scope->data[i].name)) { return nullptr; } - } - - // lookup variable in parent scopes - const Name* parent_variable = nullptr; - if (scope->parent) { parent_variable = lookup_name(scope->parent, name); } - - Name variable; - if (parent_variable == nullptr) { - variable = (Name){ - .name = name, - .rewrote_name = name, - .shadow_counter = 0, - }; - } else { - uint32_t shadow_counter = parent_variable->shadow_counter + 1; - variable = (Name){ - .name = name, - .rewrote_name = allocate_printf(arena, "%.*s.%i", (int)name.size, - name.start, shadow_counter), - .shadow_counter = shadow_counter + 1, - }; - } - - DYNARRAY_PUSH_BACK(scope, Name, arena, variable); - return &scope->data[scope->length - 1]; -} -#pragma endregion - #pragma region Expression parsing static Expr* parse_number_literal(Parser* parser, Scope* scope) { @@ -223,6 +154,7 @@ static Expr* parse_number_literal(Parser* parser, Scope* scope) Expr* result = 
ARENA_ALLOC_OBJECT(parser->permanent_arena, Expr); *result = (Expr){.tag = EXPR_CONST, + .type = typ_invalid, .source_range = token_source_range(token), .const_expr = (struct ConstExpr){.val = val}}; return result; @@ -232,7 +164,7 @@ static Expr* parse_identifier_expr(Parser* parser, Scope* scope) { const Token token = parser_previous_token(parser); - MCC_ASSERT(token.type == TOKEN_IDENTIFIER); + MCC_ASSERT(token.tag == TOKEN_IDENTIFIER); const StringView identifier = str_from_token(parser->src, token); @@ -240,20 +172,19 @@ static Expr* parse_identifier_expr(Parser* parser, Scope* scope) Expr* result = ARENA_ALLOC_OBJECT(parser->permanent_arena, Expr); // If local variable does not exist - const Name* variable = lookup_name(scope, identifier); + const Variable* variable = lookup_variable(scope, identifier); if (!variable) { const StringView error_msg = allocate_printf( parser->permanent_arena, "use of undeclared identifier '%.*s'", (int)identifier.size, identifier.start); parse_error_at(parser, error_msg, token_source_range(token)); - *result = (Expr){.tag = EXPR_VARIABLE, - .source_range = token_source_range(token), - .variable = identifier}; - } else { - *result = (Expr){.tag = EXPR_VARIABLE, - .source_range = token_source_range(token), - .variable = variable->rewrote_name}; } + + // TODO: handle the case wher variable == nullptr + *result = (Expr){.tag = EXPR_VARIABLE, + .type = typ_invalid, + .source_range = token_source_range(token), + .variable = variable}; return result; } @@ -346,7 +277,7 @@ static ParseRule rules[TOKEN_TYPES_COUNT] = { _Static_assert(sizeof(rules) / sizeof(ParseRule) == TOKEN_TYPES_COUNT, "Parse rule table should contain all token types"); -static ParseRule* get_rule(TokenType operator_type) +static ParseRule* get_rule(TokenTag operator_type) { return &rules[operator_type]; } @@ -358,13 +289,14 @@ static Expr* parse_precedence(Parser* parser, Precedence precedence, const Token previous_token = parser_previous_token(parser); - const 
PrefixParseFn prefix_rule = get_rule(previous_token.type)->prefix; + const PrefixParseFn prefix_rule = get_rule(previous_token.tag)->prefix; if (prefix_rule == NULL) { parse_panic_at_token(parser, str("Expect valid expression"), previous_token); Expr* error_expr = ARENA_ALLOC_OBJECT(parser->permanent_arena, Expr); *error_expr = (Expr){ .tag = EXPR_INVALID, + .type = typ_invalid, .source_range = token_source_range(previous_token), }; return error_expr; @@ -372,11 +304,10 @@ static Expr* parse_precedence(Parser* parser, Precedence precedence, Expr* expr = prefix_rule(parser, scope); - while (precedence <= - get_rule(parser_current_token(parser).type)->precedence) { + while (precedence <= get_rule(parser_current_token(parser).tag)->precedence) { parse_advance(parser); InfixParseFn infix_rule = - get_rule(parser_previous_token(parser).type)->infix; + get_rule(parser_previous_token(parser).tag)->infix; expr = infix_rule(parser, expr, scope); } @@ -395,7 +326,7 @@ static Expr* parse_unary_op(Parser* parser, Scope* scope) Token operator_token = parser_previous_token(parser); UnaryOpType operator_type; - switch (operator_token.type) { + switch (operator_token.tag) { case TOKEN_MINUS: operator_type = UNARY_OP_NEGATION; break; case TOKEN_TILDE: operator_type = UNARY_OP_BITWISE_TYPE_COMPLEMENT; break; case TOKEN_NOT: operator_type = UNARY_OP_NOT; break; @@ -405,6 +336,8 @@ static Expr* parse_unary_op(Parser* parser, Scope* scope) // Inner expression Expr* expr = parse_precedence(parser, PREC_UNARY, scope); + // TODO: type check + // build result // TODO: better way to handle the case where expr == NULL SourceRange result_source_range = @@ -414,6 +347,7 @@ static Expr* parse_unary_op(Parser* parser, Scope* scope) Expr* result = ARENA_ALLOC_OBJECT(parser->permanent_arena, Expr); *result = (Expr){.tag = EXPR_UNARY, + .type = typ_invalid, .source_range = result_source_range, .unary_op = (struct UnaryOpExpr){ .unary_op_type = operator_type, .inner_expr = expr}}; @@ -421,7 +355,7 
@@ static Expr* parse_unary_op(Parser* parser, Scope* scope) return result; } -static BinaryOpType binop_type_from_token_type(TokenType token_type) +static BinaryOpType binop_type_from_token_type(TokenTag token_type) { switch (token_type) { case TOKEN_PLUS: return BINARY_OP_PLUS; @@ -464,7 +398,7 @@ static Expr* parse_binary_op(Parser* parser, Expr* lhs_expr, { Token operator_token = parser_previous_token(parser); - const TokenType operator_type = operator_token.type; + const TokenTag operator_type = operator_token.tag; const ParseRule* rule = get_rule(operator_type); Expr* rhs_expr = parse_precedence( parser, @@ -474,6 +408,8 @@ static Expr* parse_binary_op(Parser* parser, Expr* lhs_expr, BinaryOpType binary_op_type = binop_type_from_token_type(operator_type); + // TODO: type check + // build result const SourceRange result_source_range = source_range_union(source_range_union(token_source_range(operator_token), @@ -482,6 +418,7 @@ static Expr* parse_binary_op(Parser* parser, Expr* lhs_expr, Expr* result = ARENA_ALLOC_OBJECT(parser->permanent_arena, Expr); *result = (Expr){.tag = EXPR_BINARY, + .type = typ_invalid, .source_range = result_source_range, .binary_op = (struct BinaryOpExpr){ .binary_op_type = binary_op_type, @@ -512,8 +449,11 @@ static Expr* parse_ternary(Parser* parser, Expr* cond, struct Scope* scope) parse_consume(parser, TOKEN_COLON, "expect ':'"); Expr* false_expr = parse_precedence(parser, PREC_TERNARY, scope); + // TODO: type check + Expr* result = ARENA_ALLOC_OBJECT(parser->permanent_arena, Expr); *result = (Expr){.tag = EXPR_TERNARY, + .type = typ_invalid, .source_range = source_range_union(cond->source_range, false_expr->source_range), .ternary = (struct TernaryExpr){.cond = cond, @@ -534,8 +474,8 @@ static Expr* parse_function_call(Parser* parser, Expr* function, struct ExprVec args_vec = {}; while (!token_match_or_eof(parser, TOKEN_RIGHT_PAREN)) { - Expr* expr = parse_expr(parser, scope); - DYNARRAY_PUSH_BACK(&args_vec, Expr*, 
&parser->scratch_arena, expr); + Expr* arg = parse_expr(parser, scope); + DYNARRAY_PUSH_BACK(&args_vec, Expr*, &parser->scratch_arena, arg); if (token_match_or_eof(parser, TOKEN_RIGHT_PAREN)) break; parse_consume(parser, TOKEN_COMMA, "expect ','"); } @@ -543,15 +483,16 @@ static Expr* parse_function_call(Parser* parser, Expr* function, parse_consume(parser, TOKEN_RIGHT_PAREN, "expect ')' at the end of a function call"); - Expr* expr = ARENA_ALLOC_OBJECT(parser->permanent_arena, Expr); - const uint32_t arg_count = u32_from_usize(args_vec.length); + Expr** args = ARENA_ALLOC_ARRAY(parser->permanent_arena, Expr*, arg_count); if (args_vec.length != 0) { memcpy(args, args_vec.data, args_vec.length * sizeof(Expr*)); } + Expr* expr = ARENA_ALLOC_OBJECT(parser->permanent_arena, Expr); *expr = (Expr){.tag = EXPR_CALL, + .type = typ_invalid, .source_range = source_range_union( function->source_range, token_source_range(parser_previous_token(parser))), @@ -560,7 +501,6 @@ static Expr* parse_function_call(Parser* parser, Expr* function, .arg_count = arg_count, .args = args, }}; - return expr; } @@ -591,10 +531,10 @@ static VariableDecl parse_decl(Parser* parser, struct Scope* scope) { const Token name = parser_current_token(parser); // TODO: proper error handling - MCC_ASSERT(name.type == TOKEN_IDENTIFIER); + MCC_ASSERT(name.tag == TOKEN_IDENTIFIER); const StringView identifier = str_from_token(parser->src, name); - const Name* variable = add_name(identifier, scope, parser->permanent_arena); + Variable* variable = add_variable(identifier, scope, parser->permanent_arena); if (!variable) { const StringView error_msg = allocate_printf(parser->permanent_arena, "redefinition of '%.*s'", @@ -604,22 +544,22 @@ static VariableDecl parse_decl(Parser* parser, struct Scope* scope) parse_advance(parser); - const Expr* initializer = nullptr; - if (parser_current_token(parser).type == TOKEN_EQUAL) { + Expr* initializer = nullptr; + if (parser_current_token(parser).tag == TOKEN_EQUAL) { 
parse_advance(parser); initializer = parse_expr(parser, scope); } parse_consume(parser, TOKEN_SEMICOLON, "expect ';'"); - return (VariableDecl){.name = variable ? variable->rewrote_name : identifier, - .initializer = initializer}; + // TODO: handle the case where variable == nullptr + return (VariableDecl){.name = variable, .initializer = initializer}; } // Find the next synchronization token (`}` or `;`) static void parser_panic_synchronize(Parser* parser) { while (true) { - const TokenType current_token_type = parser_current_token(parser).type; + const TokenTag current_token_type = parser_current_token(parser).tag; if (current_token_type == TOKEN_RIGHT_BRACE || current_token_type == TOKEN_SEMICOLON || current_token_type == TOKEN_EOF) { @@ -633,7 +573,7 @@ static BlockItem parse_block_item(Parser* parser, Scope* scope) { const Token current_token = parser_current_token(parser); BlockItem result; - if (current_token.type == TOKEN_KEYWORD_INT) { + if (current_token.tag == TOKEN_KEYWORD_INT) { parse_advance(parser); result = (BlockItem){.tag = BLOCK_ITEM_DECL, .decl = parse_decl(parser, scope)}; @@ -674,14 +614,14 @@ static Block parse_block(Parser* parser, Scope* parent_scope) struct IfStmt parse_if_stmt(Parser* parser, Scope* scope) { parse_consume(parser, TOKEN_LEFT_PAREN, "expect '('"); - const Expr* cond = parse_expr(parser, scope); + Expr* cond = parse_expr(parser, scope); parse_consume(parser, TOKEN_RIGHT_PAREN, "expect ')'"); Stmt* then = ARENA_ALLOC_OBJECT(parser->permanent_arena, Stmt); *then = parse_stmt(parser, scope); Stmt* els = nullptr; - if (parser_current_token(parser).type == TOKEN_KEYWORD_ELSE) { + if (parser_current_token(parser).tag == TOKEN_KEYWORD_ELSE) { parse_advance(parser); els = ARENA_ALLOC_OBJECT(parser->permanent_arena, Stmt); @@ -701,7 +641,7 @@ static Stmt parse_stmt(Parser* parser, Scope* scope) Stmt result; - switch (start_token.type) { + switch (start_token.tag) { case TOKEN_SEMICOLON: { parse_advance(parser); @@ -747,7 +687,7 @@ 
static Stmt parse_stmt(Parser* parser, Scope* scope) case TOKEN_KEYWORD_WHILE: { parse_advance(parser); parse_consume(parser, TOKEN_LEFT_PAREN, "expect '('"); - const Expr* cond = parse_expr(parser, scope); + Expr* cond = parse_expr(parser, scope); parse_consume(parser, TOKEN_RIGHT_PAREN, "expect ')'"); Stmt* body = ARENA_ALLOC_OBJECT(parser->permanent_arena, Stmt); @@ -763,7 +703,7 @@ static Stmt parse_stmt(Parser* parser, Scope* scope) parse_consume(parser, TOKEN_KEYWORD_WHILE, "expect \"while\""); parse_consume(parser, TOKEN_LEFT_PAREN, "expect '('"); - const Expr* cond = parse_expr(parser, scope); + Expr* cond = parse_expr(parser, scope); parse_consume(parser, TOKEN_RIGHT_PAREN, "expect ')'"); parse_consume(parser, TOKEN_SEMICOLON, "expect ';' after do/while statement"); @@ -777,7 +717,7 @@ static Stmt parse_stmt(Parser* parser, Scope* scope) // init ForInit init = {}; - switch (parser_current_token(parser).type) { + switch (parser_current_token(parser).tag) { case TOKEN_KEYWORD_INT: { // for loop introduce a new scope scope = new_scope(scope, parser->permanent_arena); @@ -799,13 +739,13 @@ static Stmt parse_stmt(Parser* parser, Scope* scope) } // cond - Expr* cond = parser_current_token(parser).type == TOKEN_SEMICOLON + Expr* cond = parser_current_token(parser).tag == TOKEN_SEMICOLON ? nullptr : parse_expr(parser, scope); parse_consume(parser, TOKEN_SEMICOLON, "expect ';'"); // post - Expr* post = parser_current_token(parser).type == TOKEN_RIGHT_PAREN + Expr* post = parser_current_token(parser).tag == TOKEN_RIGHT_PAREN ? 
nullptr : parse_expr(parser, scope); parse_consume(parser, TOKEN_RIGHT_PAREN, "expect ')'"); @@ -823,7 +763,7 @@ static Stmt parse_stmt(Parser* parser, Scope* scope) break; } default: { - const Expr* expr = parse_expr(parser, scope); + Expr* expr = parse_expr(parser, scope); parse_consume(parser, TOKEN_SEMICOLON, "expect ';'"); result = (Stmt){.tag = STMT_EXPR, .expr = expr}; break; @@ -839,14 +779,14 @@ static Stmt parse_stmt(Parser* parser, Scope* scope) struct ParameterVec { uint32_t length; uint32_t capacity; - Parameter* data; + Variable** data; }; -static Parameters parse_parameter_list(Parser* parser) +static Parameters parse_parameter_list(Parser* parser, Scope* scope) { parse_consume(parser, TOKEN_LEFT_PAREN, "Expect ("); - if (parser_current_token(parser).type == TOKEN_KEYWORD_VOID) { + if (parser_current_token(parser).tag == TOKEN_KEYWORD_VOID) { parse_advance(parser); parse_consume(parser, TOKEN_RIGHT_PAREN, "Expect )"); @@ -861,21 +801,20 @@ static Parameters parse_parameter_list(Parser* parser) while (!token_match_or_eof(parser, TOKEN_RIGHT_PAREN)) { const Token current_token = parser_current_token(parser); - switch (current_token.type) { + switch (current_token.tag) { case TOKEN_KEYWORD_VOID: { parse_error_at( parser, str("'void' must be the first and only parameter if specified"), token_source_range(current_token)); parse_advance(parser); - break; - } + } break; case TOKEN_KEYWORD_INT: { parse_advance(parser); const Token identifier_token = parser_current_token(parser); StringView identifier = {}; - if (identifier_token.type == TOKEN_IDENTIFIER) { + if (identifier_token.tag == TOKEN_IDENTIFIER) { identifier = str_from_token(parser->src, identifier_token); parse_advance(parser); } @@ -884,28 +823,29 @@ static Parameters parse_parameter_list(Parser* parser) parse_consume(parser, TOKEN_COMMA, "Expect ','"); } - const Parameter parameter = { - .name = identifier, - }; - DYNARRAY_PUSH_BACK(¶meters_vec, Parameter, &parser->scratch_arena, - parameter); + 
Variable* name = add_variable(identifier, scope, parser->permanent_arena); + // TODO: error handling + MCC_ASSERT(name != nullptr); + + DYNARRAY_PUSH_BACK(¶meters_vec, Variable*, &parser->scratch_arena, + name); } break; default: parse_panic_at_token(parser, str("Expect parameter declarator"), current_token); parse_advance(parser); + break; } - break; } parse_consume(parser, TOKEN_RIGHT_PAREN, "Expect )"); - Parameter* params = ARENA_ALLOC_ARRAY(parser->permanent_arena, Parameter, + Variable** params = ARENA_ALLOC_ARRAY(parser->permanent_arena, Variable*, parameters_vec.length); if (parameters_vec.length != 0) { memcpy(params, parameters_vec.data, - parameters_vec.length * sizeof(Parameter)); + parameters_vec.length * sizeof(Variable*)); } // TODO: warn when the parameter list is empty in pre-C23 mode @@ -920,7 +860,7 @@ static StringView parse_identifier(Parser* parser) { const Token current_token = parser_current_token(parser); - if (current_token.type != TOKEN_IDENTIFIER) { + if (current_token.tag != TOKEN_IDENTIFIER) { parse_panic_at_token(parser, str("Expect Identifier"), current_token); } parse_advance(parser); @@ -933,20 +873,27 @@ static FunctionDecl* parse_function_decl(Parser* parser) parse_consume(parser, TOKEN_KEYWORD_INT, "Expect keyword int"); StringView function_name = parse_identifier(parser); - add_name(function_name, parser->global_scope, parser->permanent_arena); + Variable* name = add_variable(function_name, parser->global_scope, + parser->permanent_arena); + // TODO: support multiple function decl + MCC_ASSERT(name != nullptr); - const Parameters parameters = parse_parameter_list(parser); + Scope* function_scope = + new_scope(parser->global_scope, parser->permanent_arena); + const Parameters parameters = parse_parameter_list(parser, function_scope); Block* body = NULL; - if (parser_current_token(parser).type == TOKEN_LEFT_BRACE) { // is definition + if (parser_current_token(parser).tag == TOKEN_LEFT_BRACE) { // is definition 
parse_advance(parser); body = ARENA_ALLOC_OBJECT(parser->permanent_arena, Block); - *body = parse_block(parser, parser->global_scope); + *body = parse_block(parser, function_scope); } else { parse_consume(parser, TOKEN_SEMICOLON, "Expect ;"); } + name->type = func_type(typ_int, parameters.length, parser->permanent_arena); + FunctionDecl* decl = ARENA_ALLOC_OBJECT(parser->permanent_arena, FunctionDecl); *decl = (FunctionDecl){ @@ -970,7 +917,7 @@ static TranslationUnit* parse_translation_unit(Parser* parser) { struct FunctionDeclVec function_decl_vec = {}; - while (parser_current_token(parser).type != TOKEN_EOF) { + while (parser_current_token(parser).tag != TOKEN_EOF) { FunctionDecl* decl = parse_function_decl(parser); DYNARRAY_PUSH_BACK(&function_decl_vec, FunctionDecl*, &parser->scratch_arena, decl); diff --git a/src/frontend/symbol_table.c b/src/frontend/symbol_table.c new file mode 100644 index 0000000..99544ab --- /dev/null +++ b/src/frontend/symbol_table.c @@ -0,0 +1,64 @@ +#include "symbol_table.h" + +#include +#include + +// TODO: use hash table +struct Scope { + uint32_t length; + uint32_t capacity; + Variable* data; + struct Scope* parent; +}; + +Scope* new_scope(Scope* parent, Arena* arena) +{ + struct Scope* map = ARENA_ALLOC_OBJECT(arena, Scope); + *map = (struct Scope){ + .parent = parent, + }; + return map; +} + +Variable* lookup_variable(const Scope* scope, StringView name) +{ + for (uint32_t i = 0; i < scope->length; ++i) { + if (str_eq(name, scope->data[i].name)) { return &scope->data[i]; } + } + + if (scope->parent == nullptr) return nullptr; + + return lookup_variable(scope->parent, name); +} + +Variable* add_variable(StringView name, Scope* scope, Arena* arena) +{ + // check variable in current scope + for (uint32_t i = 0; i < scope->length; ++i) { + if (str_eq(name, scope->data[i].name)) { return nullptr; } + } + + // lookup variable in parent scopes + const Variable* parent_variable = nullptr; + if (scope->parent) { parent_variable = 
lookup_variable(scope->parent, name); } + + Variable variable; + if (parent_variable == nullptr) { + variable = (Variable){ + .name = name, + .rewrote_name = name, + .shadow_counter = 0, + }; + } else { + uint32_t shadow_counter = parent_variable->shadow_counter + 1; + variable = (Variable){ + .name = name, + .rewrote_name = allocate_printf(arena, "%.*s.%i", (int)name.size, + name.start, shadow_counter), + .shadow_counter = shadow_counter + 1, + }; + } + + DYNARRAY_PUSH_BACK(scope, Variable, arena, variable); + return &scope->data[scope->length - 1]; +} diff --git a/src/frontend/symbol_table.h b/src/frontend/symbol_table.h new file mode 100644 index 0000000..4f6478c --- /dev/null +++ b/src/frontend/symbol_table.h @@ -0,0 +1,27 @@ +#ifndef MCC_SYMBOL_TABLE_H +#define MCC_SYMBOL_TABLE_H + +#include +#include +#include +#include + +typedef struct Variable { + StringView + name; // name in the source. This is the name used for variable lookup + StringView rewrote_name; // name after alpha renaming + uint32_t shadow_counter; // increase each time we have shadowing + const Type* type; +} Variable; + +// Represents a block scope +typedef struct Scope Scope; + +Scope* new_scope(Scope* parent, Arena* arena); + +Variable* lookup_variable(const Scope* scope, StringView name); + +// Return nullptr if a variable already exist in the same scope +Variable* add_variable(StringView name, Scope* scope, Arena* arena); + +#endif // MCC_SYMBOL_TABLE_H diff --git a/src/frontend/token.c b/src/frontend/token.c index ccd1635..d2273a1 100644 --- a/src/frontend/token.c +++ b/src/frontend/token.c @@ -3,7 +3,7 @@ #include -static const char* token_type_string(TokenType type) +static const char* token_type_string(TokenTag type) { switch (type) { case TOKEN_INVALID: MCC_UNREACHABLE(); @@ -86,7 +86,7 @@ void print_tokens(const char* src, const Tokens* tokens, if (src_padding_size < 0) src_padding_size = 0; printf("%-10s src=\"%.*s\"%*s line=%-2i column=%-2i offset=%u\n", - 
token_type_string(token.type), (int)token.size, src + token.start, + token_type_string(token.tag), (int)token.size, src + token.start, src_padding_size, "", line_column.line, line_column.column, token.start); } diff --git a/src/frontend/type.c b/src/frontend/type.c new file mode 100644 index 0000000..7f1b798 --- /dev/null +++ b/src/frontend/type.c @@ -0,0 +1,68 @@ +#include +#include +#include + +const Type* typ_invalid = &(const Type){ + .tag = TYPE_INVALID, +}; + +const Type* typ_void = &(const Type){ + .tag = TYPE_VOID, +}; + +const Type* typ_int = (const Type*)&(const IntegerType){ + .base = + { + .tag = TYPE_INTEGER, + .size = 4, + .alignment = 4, + }, + .is_unsigned = false, // C 'int' is a signed integer type + .name = "int", +}; + +const Type* func_type(const Type* return_type, uint32_t param_count, + Arena* arena) +{ + FunctionType* result = ARENA_ALLOC_OBJECT(arena, FunctionType); + *result = (FunctionType){ + .base = + { + .tag = TYPE_FUNCTION, + .size = 1, + .alignment = 1, + }, + .param_count = param_count, + .return_type = return_type, + }; + + return (const Type*)result; +} + +void format_type_to(StringBuffer* buffer, const Type* typ) +{ + switch (typ->tag) { + case TYPE_INVALID: MCC_UNREACHABLE(); return; + case TYPE_VOID: string_buffer_append(buffer, str("void")); return; + case TYPE_INTEGER: { + const IntegerType* int_typ = (const IntegerType*)typ; + string_buffer_append(buffer, str(int_typ->name)); + } + return; + case TYPE_FUNCTION: { + const FunctionType* func_typ = (const FunctionType*)typ; + format_type_to(buffer, func_typ->return_type); + string_buffer_push(buffer, '('); + if (func_typ->param_count != 0) { + for (uint32_t i = 0; i < func_typ->param_count; ++i) { + if (i != 0) { string_buffer_append(buffer, str(", ")); } + string_buffer_append(buffer, str("int")); + } + } else { + string_buffer_append(buffer, str("void")); + } + string_buffer_push(buffer, ')'); + } + return; + } +} diff --git a/src/frontend/type_check.c b/src/frontend/type_check.c new file mode 100644 index
0000000..6dc3347 --- /dev/null +++ b/src/frontend/type_check.c @@ -0,0 +1,308 @@ +#include +#include +#include +#include +#include + +#include "symbol_table.h" + +// All the type checking functions in this file return `false` to indicate +// encountering an error, which is used to skip further checks + +struct ErrorVec { + size_t length; + size_t capacity; + Error* data; +}; + +typedef struct Context { + struct ErrorVec errors; + Arena* permanent_arena; +} Context; + +#pragma region error reporter +static void error_at(StringView msg, SourceRange range, Context* context) +{ + Error error = (Error){.msg = msg, .range = range}; + DYNARRAY_PUSH_BACK(&context->errors, Error, context->permanent_arena, error); +} + +static void report_invalid_unary_args(const Expr* expr, Context* context) +{ + StringBuffer buffer = string_buffer_new(context->permanent_arena); + string_buffer_append(&buffer, str("invalid argument type '")); + format_type_to(&buffer, expr->unary_op.inner_expr->type); + string_buffer_append(&buffer, str("' to unary expression")); + error_at(str_from_buffer(&buffer), expr->unary_op.inner_expr->source_range, + context); +} + +static void report_invalid_binary_args(const Expr* expr, Context* context) +{ + StringBuffer buffer = string_buffer_new(context->permanent_arena); + string_buffer_append(&buffer, + str("invalid operands to binary expression ('")); + format_type_to(&buffer, expr->binary_op.lhs->type); + string_buffer_append(&buffer, str("' and '")); + format_type_to(&buffer, expr->binary_op.rhs->type); + string_buffer_append(&buffer, str("')")); + error_at(str_from_buffer(&buffer), expr->source_range, context); +} + +static void report_incompatible_return(const Expr* expr, Context* context) +{ + StringBuffer buffer = string_buffer_new(context->permanent_arena); + string_buffer_append(&buffer, str("returning '")); + format_type_to(&buffer, expr->type); + string_buffer_append( + &buffer, str("' from a function with incompatible result type 'int'")); + 
error_at(str_from_buffer(&buffer), expr->source_range, context); +} + +static void report_calling_noncallable(const Expr* function, Context* context) +{ + StringBuffer buffer = string_buffer_new(context->permanent_arena); + string_buffer_append(&buffer, str("called object with type '")); + format_type_to(&buffer, function->type); + string_buffer_append(&buffer, str("', which is not callable")); + error_at(str_from_buffer(&buffer), function->source_range, context); +} + +static void report_arg_count_mismatch(const Expr* function, + uint32_t param_count, uint32_t arg_count, + Context* context) +{ + const StringView msg = allocate_printf( + context->permanent_arena, + "too %s arguments to function call, expected %u, have %u", + param_count > arg_count ? "few" : "many", param_count, arg_count); + error_at(msg, function->source_range, context); +} + +static void report_wrong_arg_type(const Expr* arg, Context* context) +{ + StringBuffer buffer = string_buffer_new(context->permanent_arena); + string_buffer_append(&buffer, str("passing '")); + format_type_to(&buffer, arg->type); + string_buffer_append(&buffer, str("' to parameter of type 'int'")); + error_at(str_from_buffer(&buffer), arg->source_range, context); +} +#pragma endregion + +[[nodiscard]] +static bool type_check_expr(Expr* expr, Context* context); + +[[nodiscard]] +static bool type_check_function_call(Expr* function_call, Context* context) +{ + MCC_ASSERT(function_call->tag == EXPR_CALL); + + Expr* function = function_call->call.function; + if (!type_check_expr(function, context)) { return false; } + + if (function->type->tag != TYPE_FUNCTION) { + report_calling_noncallable(function, context); + return false; + } + + const FunctionType* function_type = (const FunctionType*)function->type; + + const uint32_t arg_count = function_call->call.arg_count; + if (function_type->param_count != arg_count) { + report_arg_count_mismatch(function, function_type->param_count, arg_count, + context); + return false; + } + + 
for (uint32_t i = 0; i < arg_count; ++i) { + Expr* arg = function_call->call.args[i]; + if (!type_check_expr(arg, context)) { return false; } + + if (arg->type->tag != TYPE_INTEGER) { + report_wrong_arg_type(arg, context); + return false; + } + } + + function_call->type = typ_int; + return true; +} + +[[nodiscard]] +static bool type_check_expr(Expr* expr, Context* context) +{ + expr->type = typ_invalid; + switch (expr->tag) { + case EXPR_INVALID: MCC_UNREACHABLE(); break; + case EXPR_CONST: expr->type = typ_int; return true; + case EXPR_VARIABLE: + MCC_ASSERT(expr->variable->type != nullptr); + MCC_ASSERT(expr->variable->type != TYPE_INVALID); + + expr->type = expr->variable->type; + return true; + case EXPR_UNARY: + if (!type_check_expr(expr->unary_op.inner_expr, context)) { return false; } + if (expr->unary_op.inner_expr->type->tag != TYPE_INTEGER) { + report_invalid_unary_args(expr, context); + return false; + } + expr->type = expr->unary_op.inner_expr->type; + return true; + case EXPR_BINARY: + if (!type_check_expr(expr->binary_op.lhs, context) || + !type_check_expr(expr->binary_op.rhs, context)) { + return false; + } + + if (expr->binary_op.lhs->type->tag != TYPE_INTEGER || + expr->binary_op.rhs->type->tag != TYPE_INTEGER) { + report_invalid_binary_args(expr, context); + return false; + } + + expr->type = typ_int; + return true; + case EXPR_TERNARY: + if (!type_check_expr(expr->ternary.cond, context) || + !type_check_expr(expr->ternary.false_expr, context) || + !type_check_expr(expr->ternary.true_expr, context)) { + MCC_ASSERT(expr->ternary.cond->type->tag == TYPE_INTEGER); + // TODO: check the two ternary branches has the same type + return false; + } + + MCC_ASSERT(expr->ternary.cond->type->tag == TYPE_INTEGER); + // TODO: check the two branches has the same type + expr->type = expr->ternary.true_expr->type; + return true; + case EXPR_CALL: return type_check_function_call(expr, context); + } +} + +static bool type_check_block(Block* block, Context* context); 
+ +[[nodiscard]] static bool type_check_variable_decl(VariableDecl* decl, + Context* context); + +[[nodiscard]] +static bool type_check_stmt(Stmt* stmt, Context* context) +{ + switch (stmt->tag) { + case STMT_INVALID: MCC_UNREACHABLE(); + case STMT_EMPTY: return true; + case STMT_EXPR: return type_check_expr(stmt->expr, context); + case STMT_COMPOUND: return type_check_block(&stmt->compound, context); + case STMT_RETURN: { + // TODO: check it return the expect function return type + Expr* expr = stmt->ret.expr; + if (!type_check_expr(expr, context)) { return false; } + + if (expr->type->tag != TYPE_INTEGER) { + report_incompatible_return(expr, context); + return false; + } + return true; + } + case STMT_IF: { + Expr* cond = stmt->if_then.cond; + if (!type_check_expr(cond, context)) { return false; } + MCC_ASSERT(cond->type->tag == TYPE_INTEGER); + bool result = type_check_stmt(stmt->if_then.then, context); + if (stmt->if_then.els != nullptr) { + result &= type_check_stmt(stmt->if_then.els, context); + } + return result; + } + case STMT_WHILE: [[fallthrough]]; + case STMT_DO_WHILE: { + Expr* cond = stmt->while_loop.cond; + if (!type_check_expr(cond, context)) { return false; } + MCC_ASSERT(cond->type->tag == TYPE_INTEGER); + return type_check_stmt(stmt->while_loop.body, context); + } + case STMT_FOR: { + ForInit init = stmt->for_loop.init; + Expr* cond = stmt->for_loop.cond; + Stmt* body = stmt->for_loop.body; + Expr* post = stmt->for_loop.post; + + bool result = true; + switch (init.tag) { + case FOR_INIT_INVALID: MCC_UNREACHABLE(); + case FOR_INIT_DECL: + if (!type_check_variable_decl(init.decl, context)) { return false; } + break; + case FOR_INIT_EXPR: { + if (init.expr) { result &= type_check_expr(init.expr, context); } + } break; + } + + if (cond) { result &= type_check_expr(cond, context); } + result &= type_check_stmt(body, context); + if (post) { result &= type_check_expr(post, context); } + if (!result) { return result; } + + return true; + } + case 
STMT_BREAK: + case STMT_CONTINUE: return true; + } +} + +[[nodiscard]] static bool type_check_variable_decl(VariableDecl* decl, + Context* context) +{ + decl->name->type = typ_int; + + if (decl->initializer) { + if (!type_check_expr(decl->initializer, context)) { return false; } + + // TODO: error handling + MCC_ASSERT(decl->initializer->type->tag == TYPE_INTEGER); + } + return true; +} + +static bool type_check_block(Block* block, Context* context) +{ + bool result = true; + for (uint32_t i = 0; i < block->child_count; ++i) { + BlockItem* item = &block->children[i]; + switch (item->tag) { + case BLOCK_ITEM_STMT: + result &= type_check_stmt(&item->stmt, context); + break; + case BLOCK_ITEM_DECL: + if (!type_check_variable_decl(&item->decl, context)) { return false; } + break; + } + } + return result; +} + +static void type_check_function_decl(FunctionDecl* decl, Context* context) +{ + if (decl->body != nullptr) { + for (uint32_t i = 0; i < decl->params.length; ++i) { + Variable* param = decl->params.data[i]; + param->type = typ_int; + } + + type_check_block(decl->body, context); + } +} + +ErrorsView type_check(TranslationUnit* ast, Arena* permanent_arena) +{ + Context context = {.permanent_arena = permanent_arena}; + + for (uint32_t i = 0; i < ast->decl_count; ++i) { + type_check_function_decl(ast->decls[i], &context); + } + + return (ErrorsView){ + .data = context.errors.data, + .length = context.errors.length, + }; +} diff --git a/src/ir/ir_generator.c b/src/ir/ir_generator.c index c31bcf6..be71cef 100644 --- a/src/ir/ir_generator.c +++ b/src/ir/ir_generator.c @@ -5,6 +5,8 @@ #include #include +#include "../frontend/symbol_table.h" + #define ASSIGNMENTS \ case BINARY_OP_ASSIGNMENT: \ case BINARY_OP_PLUS_EQUAL: \ @@ -356,7 +358,7 @@ static IRValue emit_ir_instructions_from_expr(const Expr* expr, default: return emit_ir_instructions_from_binary_expr(expr, context); } } - case EXPR_VARIABLE: return ir_variable(expr->variable); + case EXPR_VARIABLE: return 
ir_variable(expr->variable->rewrote_name); case EXPR_TERNARY: { const StringView true_label = create_fresh_label_name(context, "ternary_true"); @@ -414,8 +416,9 @@ static void emit_ir_instructions_from_decl(const VariableDecl* decl, if (decl->initializer != nullptr) { const IRValue value = emit_ir_instructions_from_expr(decl->initializer, context); - push_instruction(context, - ir_unary_instr(IR_COPY, ir_variable(decl->name), value)); + push_instruction( + context, + ir_unary_instr(IR_COPY, ir_variable(decl->name->rewrote_name), value)); } } @@ -685,8 +688,8 @@ typedef struct IRFunctionVec { uint32_t capacity; } IRFunctionVec; -IRGenerationResult ir_generate(TranslationUnit* ast, Arena* permanent_arena, - Arena scratch_arena) +IRGenerationResult ir_generate(const TranslationUnit* ast, + Arena* permanent_arena, Arena scratch_arena) { IRFunctionVec ir_function_vec = {}; diff --git a/src/main.c b/src/main.c index c1b6bc3..6b42379 100644 --- a/src/main.c +++ b/src/main.c @@ -8,7 +8,9 @@ #include #include #include +#include #include +#include #include #include @@ -170,14 +172,21 @@ int main(int argc, char* argv[]) // Failed to parse program return 1; } - + TranslationUnit* tu = parse_result.ast; if (args.stop_after_parser) { - ast_print_translation_unit(parse_result.ast); + ast_print_translation_unit(tu); return 0; } + ErrorsView type_errors = type_check(tu, &permanent_arena); + if (type_errors.length != 0) { + print_diagnostics(type_errors, &diagnostics_context); + return 1; + } + if (args.stop_after_semantic_analysis) { return 0; } + IRGenerationResult ir_gen_result = - ir_generate(parse_result.ast, &permanent_arena, scratch_arena); + ir_generate(tu, &permanent_arena, scratch_arena); if (ir_gen_result.program == NULL) { // Failed to generate IR diff --git a/src/utils/cli_args.c b/src/utils/cli_args.c index ac71be7..9f980e1 100644 --- a/src/utils/cli_args.c +++ b/src/utils/cli_args.c @@ -16,6 +16,8 @@ static const Option options[] = { {"--help, -h", "prints this help 
message"}, {"--lex", "lexing only and then dump the result tokens"}, {"--parse", "lex and parse, and then dump the result AST"}, + {"--validate", + "lex, parse, perform the semantic analysis on result AST, and then stop"}, {"--ir", "generate the IR, and then dump the result AST"}, {"--codegen", "generate the assembly, and then dump the result rather than " "saving to a file"}, @@ -50,6 +52,8 @@ CliArgs parse_cli_args(int argc, char** argv) result.stop_after_lexer = true; } else if (str_eq(arg, str("--parse"))) { result.stop_after_parser = true; + } else if (str_eq(arg, str("--validate"))) { + result.stop_after_semantic_analysis = true; } else if (str_eq(arg, str("-S"))) { result.compile_only = true; } else if (str_eq(arg, str("-c"))) { @@ -64,6 +68,7 @@ CliArgs parse_cli_args(int argc, char** argv) stderr, "mcc: fatal error: unrecognized command-line option: '%.*s'\n", (int)arg.size, arg.start); + exit(1); } else { // TODO: support more than one source file result.source_filename = argv[i]; diff --git a/tests/test_data/invalid_semantics/name_resolution/test_config.toml b/tests/test_data/invalid_semantics/name_resolution/test_config.toml deleted file mode 100644 index bf66b63..0000000 --- a/tests/test_data/invalid_semantics/name_resolution/test_config.toml +++ /dev/null @@ -1,3 +0,0 @@ -command = "{mcc} --parse {filename}" -return_code = 1 -snapshot_test_stderr = true diff --git a/tests/test_data/invalid_semantics/loop/test_config.toml b/tests/test_data/invalid_semantics/test_config.toml similarity index 100% rename from tests/test_data/invalid_semantics/loop/test_config.toml rename to tests/test_data/invalid_semantics/test_config.toml diff --git a/tests/test_data/invalid_semantics/type_checking/call_int_as_function.c b/tests/test_data/invalid_semantics/type_checking/call_int_as_function.c new file mode 100644 index 0000000..65015d2 --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/call_int_as_function.c @@ -0,0 +1,4 @@ +int main(void) +{ + return 
1(2); +} diff --git a/tests/test_data/invalid_semantics/type_checking/call_int_as_function.stderr.approved.txt b/tests/test_data/invalid_semantics/type_checking/call_int_as_function.stderr.approved.txt new file mode 100644 index 0000000..df8d3bc --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/call_int_as_function.stderr.approved.txt @@ -0,0 +1,4 @@ +{{filename}}:3:10: Error: called object with type 'int', which is not callable +3 | return 1(2); + | ^ + diff --git a/tests/test_data/invalid_semantics/type_checking/call_int_var_as_function.c b/tests/test_data/invalid_semantics/type_checking/call_int_var_as_function.c new file mode 100644 index 0000000..311a1f7 --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/call_int_var_as_function.c @@ -0,0 +1,5 @@ +int main(void) +{ + int x = 42; + return x(2); +} diff --git a/tests/test_data/invalid_semantics/type_checking/call_int_var_as_function.stderr.approved.txt b/tests/test_data/invalid_semantics/type_checking/call_int_var_as_function.stderr.approved.txt new file mode 100644 index 0000000..5df58d3 --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/call_int_var_as_function.stderr.approved.txt @@ -0,0 +1,4 @@ +{{filename}}:4:10: Error: called object with type 'int', which is not callable +4 | return x(2); + | ^ + diff --git a/tests/test_data/invalid_semantics/type_checking/extra_parameter.c b/tests/test_data/invalid_semantics/type_checking/extra_parameter.c new file mode 100644 index 0000000..6ad696e --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/extra_parameter.c @@ -0,0 +1,9 @@ +int f(int x) +{ + return x; +} + +int main(void) +{ + return f(1, 2); +} diff --git a/tests/test_data/invalid_semantics/type_checking/extra_parameter.stderr.approved.txt b/tests/test_data/invalid_semantics/type_checking/extra_parameter.stderr.approved.txt new file mode 100644 index 0000000..206641f --- /dev/null +++ 
b/tests/test_data/invalid_semantics/type_checking/extra_parameter.stderr.approved.txt @@ -0,0 +1,4 @@ +{{filename}}:7:10: Error: too many arguments to function call, expected 1, have 2 +7 | return f(1, 2); + | ^ + diff --git a/tests/test_data/invalid_semantics/type_checking/function_in_expr.c b/tests/test_data/invalid_semantics/type_checking/function_in_expr.c new file mode 100644 index 0000000..8aff86c --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/function_in_expr.c @@ -0,0 +1,7 @@ +int f(void); + +int main(void) +{ + -f; + 1 + f; +} diff --git a/tests/test_data/invalid_semantics/type_checking/function_in_expr.stderr.approved.txt b/tests/test_data/invalid_semantics/type_checking/function_in_expr.stderr.approved.txt new file mode 100644 index 0000000..69fb779 --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/function_in_expr.stderr.approved.txt @@ -0,0 +1,8 @@ +{{filename}}:4:4: Error: invalid argument type 'int(void)' to unary expression +4 | -f; + | ^ + +{{filename}}:5:3: Error: invalid operands to binary expression ('int' and 'int(void)') +5 | 1 + f; + | ^~~~~ + diff --git a/tests/test_data/invalid_semantics/type_checking/passing_function_as_int.c b/tests/test_data/invalid_semantics/type_checking/passing_function_as_int.c new file mode 100644 index 0000000..5b665e8 --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/passing_function_as_int.c @@ -0,0 +1,9 @@ +int f(int x, int y) +{ + return x + y; +} + +int main(void) +{ + return f(42, f); +} diff --git a/tests/test_data/invalid_semantics/type_checking/passing_function_as_int.stderr.approved.txt b/tests/test_data/invalid_semantics/type_checking/passing_function_as_int.stderr.approved.txt new file mode 100644 index 0000000..4971882 --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/passing_function_as_int.stderr.approved.txt @@ -0,0 +1,4 @@ +{{filename}}:7:16: Error: passing 'int(int, int)' to parameter of type 'int' +7 | return f(42, f); + | 
^ + diff --git a/tests/test_data/invalid_semantics/type_checking/return_function.c b/tests/test_data/invalid_semantics/type_checking/return_function.c new file mode 100644 index 0000000..0679a68 --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/return_function.c @@ -0,0 +1,6 @@ +int f(void); + +int main(void) +{ + return f; +} diff --git a/tests/test_data/invalid_semantics/type_checking/return_function.stderr.approved.txt b/tests/test_data/invalid_semantics/type_checking/return_function.stderr.approved.txt new file mode 100644 index 0000000..9d74690 --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/return_function.stderr.approved.txt @@ -0,0 +1,4 @@ +{{filename}}:4:10: Error: returning 'int(void)' from a function with incompatible result type 'int' +4 | return f; + | ^ + diff --git a/tests/test_data/invalid_semantics/type_checking/too_few_parameters.c b/tests/test_data/invalid_semantics/type_checking/too_few_parameters.c new file mode 100644 index 0000000..a279c51 --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/too_few_parameters.c @@ -0,0 +1,9 @@ +int f(int x) +{ + return x; +} + +int main(void) +{ + return f(); +} diff --git a/tests/test_data/invalid_semantics/type_checking/too_few_parameters.stderr.approved.txt b/tests/test_data/invalid_semantics/type_checking/too_few_parameters.stderr.approved.txt new file mode 100644 index 0000000..10cba1b --- /dev/null +++ b/tests/test_data/invalid_semantics/type_checking/too_few_parameters.stderr.approved.txt @@ -0,0 +1,4 @@ +{{filename}}:7:10: Error: too few arguments to function call, expected 1, have 0 +7 | return f(); + | ^ + diff --git a/tests/unit_tests/lexer_test.cpp b/tests/unit_tests/lexer_test.cpp index 8b17dd3..c06d954 100644 --- a/tests/unit_tests/lexer_test.cpp +++ b/tests/unit_tests/lexer_test.cpp @@ -21,49 +21,48 @@ TEST_CASE("Lexer lex symbols", "[lexer]") + ++ += - -- -= -> * *= / /= % %= , . 
?:;)"; - static constexpr TokenType expected[] = {TOKEN_EQUAL, - TOKEN_EQUAL_EQUAL, - TOKEN_NOT_EQUAL, - TOKEN_LESS, - TOKEN_LESS_LESS, - TOKEN_LESS_EQUAL, - TOKEN_GREATER, - TOKEN_GREATER_GREATER, - TOKEN_GREATER_EQUAL, - TOKEN_AMPERSAND, - TOKEN_AMPERSAND_AMPERSAND, - TOKEN_AMPERSAND_EQUAL, - TOKEN_BAR, - TOKEN_BAR_BAR, - TOKEN_BAR_EQUAL, - TOKEN_NOT, - TOKEN_CARET, - TOKEN_CARET_EQUAL, - TOKEN_LESS_LESS_EQUAL, - TOKEN_GREATER_GREATER_EQUAL, - TOKEN_PLUS, - TOKEN_PLUS_PLUS, - TOKEN_PLUS_EQUAL, - TOKEN_MINUS, - TOKEN_MINUS_MINUS, - TOKEN_MINUS_EQUAL, - TOKEN_MINUS_GREATER, - TOKEN_STAR, - TOKEN_STAR_EQUAL, - TOKEN_SLASH, - TOKEN_SLASH_EQUAL, - TOKEN_PERCENT, - TOKEN_PERCENT_EQUAL, - TOKEN_COMMA, - TOKEN_DOT, - TOKEN_QUESTION, - TOKEN_COLON, - TOKEN_SEMICOLON, - TOKEN_EOF}; + static constexpr TokenTag expected[] = {TOKEN_EQUAL, + TOKEN_EQUAL_EQUAL, + TOKEN_NOT_EQUAL, + TOKEN_LESS, + TOKEN_LESS_LESS, + TOKEN_LESS_EQUAL, + TOKEN_GREATER, + TOKEN_GREATER_GREATER, + TOKEN_GREATER_EQUAL, + TOKEN_AMPERSAND, + TOKEN_AMPERSAND_AMPERSAND, + TOKEN_AMPERSAND_EQUAL, + TOKEN_BAR, + TOKEN_BAR_BAR, + TOKEN_BAR_EQUAL, + TOKEN_NOT, + TOKEN_CARET, + TOKEN_CARET_EQUAL, + TOKEN_LESS_LESS_EQUAL, + TOKEN_GREATER_GREATER_EQUAL, + TOKEN_PLUS, + TOKEN_PLUS_PLUS, + TOKEN_PLUS_EQUAL, + TOKEN_MINUS, + TOKEN_MINUS_MINUS, + TOKEN_MINUS_EQUAL, + TOKEN_MINUS_GREATER, + TOKEN_STAR, + TOKEN_STAR_EQUAL, + TOKEN_SLASH, + TOKEN_SLASH_EQUAL, + TOKEN_PERCENT, + TOKEN_PERCENT_EQUAL, + TOKEN_COMMA, + TOKEN_DOT, + TOKEN_QUESTION, + TOKEN_COLON, + TOKEN_SEMICOLON, + TOKEN_EOF}; const auto tokens = lex(input, &permanent_arena, scratch_arena); - const std::span token_types(tokens.token_types, - tokens.token_count); + const std::span token_types(tokens.token_types, tokens.token_count); REQUIRE_THAT(expected, RangeEquals(token_types)); } @@ -77,22 +76,21 @@ TEST_CASE("Lexer lex keywords", "[lexer]") do while for break continue let)"; - static constexpr TokenType expected[] = {TOKEN_KEYWORD_INT, - 
TOKEN_KEYWORD_VOID, - TOKEN_KEYWORD_RETURN, - TOKEN_KEYWORD_TYPEDEF, - TOKEN_KEYWORD_IF, - TOKEN_KEYWORD_ELSE, - TOKEN_KEYWORD_DO, - TOKEN_KEYWORD_WHILE, - TOKEN_KEYWORD_FOR, - TOKEN_KEYWORD_BREAK, - TOKEN_KEYWORD_CONTINUE, - TOKEN_IDENTIFIER, - TOKEN_EOF}; + static constexpr TokenTag expected[] = {TOKEN_KEYWORD_INT, + TOKEN_KEYWORD_VOID, + TOKEN_KEYWORD_RETURN, + TOKEN_KEYWORD_TYPEDEF, + TOKEN_KEYWORD_IF, + TOKEN_KEYWORD_ELSE, + TOKEN_KEYWORD_DO, + TOKEN_KEYWORD_WHILE, + TOKEN_KEYWORD_FOR, + TOKEN_KEYWORD_BREAK, + TOKEN_KEYWORD_CONTINUE, + TOKEN_IDENTIFIER, + TOKEN_EOF}; const auto tokens = lex(input, &permanent_arena, scratch_arena); - const std::span token_types(tokens.token_types, - tokens.token_count); + const std::span token_types(tokens.token_types, tokens.token_count); REQUIRE_THAT(expected, RangeEquals(token_types)); }