From 89d8493192feee7c7acaa52924a265bde745a32d Mon Sep 17 00:00:00 2001 From: Mattias Hansson Date: Sat, 1 Nov 2025 18:50:01 +0100 Subject: [PATCH] Added compiler, command and the BYTE command for a first iteration. --- internal/commands/byte.go | 157 ++++++++++ internal/commands/byte_test.go | 440 +++++++++++++++++++++++++++++ internal/compiler/command.go | 59 ++++ internal/compiler/compiler.go | 116 ++++++++ internal/compiler/compiler_test.go | 148 ++++++++++ internal/compiler/context.go | 55 ++++ internal/utils/utils.go | 263 +++++++++++++++++ internal/utils/utils_test.go | 272 ++++++++++++++++++ 8 files changed, 1510 insertions(+) create mode 100644 internal/commands/byte.go create mode 100644 internal/commands/byte_test.go create mode 100644 internal/compiler/command.go create mode 100644 internal/compiler/compiler.go create mode 100644 internal/compiler/compiler_test.go create mode 100644 internal/compiler/context.go create mode 100644 internal/utils/utils.go create mode 100644 internal/utils/utils_test.go diff --git a/internal/commands/byte.go b/internal/commands/byte.go new file mode 100644 index 0000000..74b8030 --- /dev/null +++ b/internal/commands/byte.go @@ -0,0 +1,157 @@ +package commands + +import ( + "fmt" + "strings" + + "c65gm/internal/compiler" + "c65gm/internal/preproc" + "c65gm/internal/utils" +) + +// ByteCommand handles BYTE variable declarations +// Syntax: +// +// BYTE varname # byte with init = 0 +// BYTE varname = value # byte with init value +// BYTE varname @ address # byte at absolute address +// BYTE CONST varname = value # constant byte +type ByteCommand struct { + varName string + value uint16 + isConst bool + isAbs bool +} + +func (c *ByteCommand) WillHandle(line preproc.Line) bool { + params, err := utils.ParseParams(line.Text) + if err != nil || len(params) == 0 { + return false + } + return strings.ToUpper(params[0]) == "BYTE" +} + +func (c *ByteCommand) Interpret(line preproc.Line, ctx *compiler.CompilerContext) error { + // 
Clear state + c.varName = "" + c.value = 0 + c.isConst = false + c.isAbs = false + + params, err := utils.ParseParams(line.Text) + if err != nil { + return err + } + + paramCount := len(params) + + // Validate parameter count + if paramCount != 2 && paramCount != 4 && paramCount != 5 { + return fmt.Errorf("BYTE: wrong number of parameters (%d)", paramCount) + } + + var varName string + var value int64 + scope := ctx.FunctionHandler.CurrentFunction() + + // Create constant lookup function + constLookup := func(name string) (int64, bool) { + sym := ctx.SymbolTable.Lookup(name, ctx.CurrentScope()) + if sym != nil && sym.IsConst() { + return int64(sym.Value), true + } + return 0, false + } + + switch paramCount { + case 2: + // BYTE varname + varName = params[1] + value = 0 + + if !utils.ValidateIdentifier(varName) { + return fmt.Errorf("BYTE: invalid identifier %q", varName) + } + + err = ctx.SymbolTable.AddVar(varName, scope, compiler.KindByte, uint16(value)) + + case 4: + // BYTE varname = value OR BYTE varname @ address + varName = params[1] + operator := params[2] + valueStr := params[3] + + if !utils.ValidateIdentifier(varName) { + return fmt.Errorf("BYTE: invalid identifier %q", varName) + } + + value, err = utils.EvaluateExpression(valueStr, constLookup) + if err != nil { + return fmt.Errorf("BYTE: invalid value %q: %w", valueStr, err) + } + + if operator == "=" { + // BYTE varname = value + if value < 0 || value > 255 { + return fmt.Errorf("BYTE: init value %d out of range (0-255)", value) + } + err = ctx.SymbolTable.AddVar(varName, scope, compiler.KindByte, uint16(value)) + + } else if operator == "@" { + // BYTE varname @ address + if value < 0 || value > 0xFFFF { + return fmt.Errorf("BYTE: absolute address $%X out of range", value) + } + c.isAbs = true + err = ctx.SymbolTable.AddAbsolute(varName, scope, compiler.KindByte, uint16(value)) + + } else { + return fmt.Errorf("BYTE: expected '=' or '@', got %q", operator) + } + + case 5: + // BYTE CONST varname = 
value + constKeyword := strings.ToUpper(params[1]) + varName = params[2] + operator := params[3] + valueStr := params[4] + + if constKeyword != "CONST" { + return fmt.Errorf("BYTE: expected CONST keyword, got %q", params[1]) + } + + if operator != "=" { + return fmt.Errorf("BYTE: expected '=', got %q", operator) + } + + if !utils.ValidateIdentifier(varName) { + return fmt.Errorf("BYTE: invalid identifier %q", varName) + } + + value, err = utils.EvaluateExpression(valueStr, constLookup) + if err != nil { + return fmt.Errorf("BYTE: invalid value %q: %w", valueStr, err) + } + + if value < 0 || value > 255 { + return fmt.Errorf("BYTE: const value %d out of range (0-255)", value) + } + + c.isConst = true + err = ctx.SymbolTable.AddConst(varName, scope, compiler.KindByte, uint16(value)) + } + + if err != nil { + return fmt.Errorf("BYTE: %w", err) + } + + c.varName = varName + c.value = uint16(value) + + return nil +} + +func (c *ByteCommand) Generate(_ *compiler.CompilerContext) ([]string, error) { + // Variables are rendered by assembleOutput, not by individual commands + return nil, nil +} diff --git a/internal/commands/byte_test.go b/internal/commands/byte_test.go new file mode 100644 index 0000000..16874fd --- /dev/null +++ b/internal/commands/byte_test.go @@ -0,0 +1,440 @@ +package commands + +import ( + "strings" + "testing" + + "c65gm/internal/compiler" + "c65gm/internal/preproc" +) + +func TestByteCommand_WillHandle(t *testing.T) { + tests := []struct { + name string + text string + want bool + }{ + { + name: "handles BYTE", + text: "BYTE x", + want: true, + }, + { + name: "handles byte lowercase", + text: "byte x", + want: true, + }, + { + name: "handles BYTE with init", + text: "BYTE x = 10", + want: true, + }, + { + name: "handles BYTE at absolute", + text: "BYTE x @ $C000", + want: true, + }, + { + name: "handles BYTE CONST", + text: "BYTE CONST x = 10", + want: true, + }, + { + name: "does not handle WORD", + text: "WORD x", + want: false, + }, + { + name: 
"does not handle empty", + text: "", + want: false, + }, + { + name: "does not handle other command", + text: "LET x 5", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := &ByteCommand{} + line := preproc.Line{Text: tt.text} + got := cmd.WillHandle(line) + if got != tt.want { + t.Errorf("WillHandle() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestByteCommand_Interpret(t *testing.T) { + tests := []struct { + name string + text string + wantErr bool + errContains string + checkVar func(*testing.T, *compiler.CompilerContext) + }{ + { + name: "simple byte", + text: "BYTE x", + wantErr: false, + checkVar: func(t *testing.T, ctx *compiler.CompilerContext) { + sym := ctx.SymbolTable.Lookup("x", nil) + if sym == nil { + t.Fatal("Expected variable x to be declared") + } + if !sym.IsByte() { + t.Error("Expected byte variable") + } + if sym.IsConst() { + t.Error("Expected regular variable, not const") + } + if sym.Value != 0 { + t.Errorf("Expected init value 0, got %d", sym.Value) + } + }, + }, + { + name: "byte with decimal init", + text: "BYTE counter = 42", + wantErr: false, + checkVar: func(t *testing.T, ctx *compiler.CompilerContext) { + sym := ctx.SymbolTable.Lookup("counter", nil) + if sym == nil { + t.Fatal("Expected variable counter to be declared") + } + if sym.Value != 42 { + t.Errorf("Expected init value 42, got %d", sym.Value) + } + }, + }, + { + name: "byte with hex init", + text: "BYTE status = $FF", + wantErr: false, + checkVar: func(t *testing.T, ctx *compiler.CompilerContext) { + sym := ctx.SymbolTable.Lookup("status", nil) + if sym == nil { + t.Fatal("Expected variable status to be declared") + } + if sym.Value != 255 { + t.Errorf("Expected init value 255, got %d", sym.Value) + } + }, + }, + { + name: "byte at absolute address", + text: "BYTE ptr @ $C000", + wantErr: false, + checkVar: func(t *testing.T, ctx *compiler.CompilerContext) { + sym := ctx.SymbolTable.Lookup("ptr", nil) + if sym == nil 
{ + t.Fatal("Expected variable ptr to be declared") + } + if !sym.IsAbsolute() { + t.Error("Expected absolute variable") + } + if sym.AbsAddr != 0xC000 { + t.Errorf("Expected address $C000, got $%04X", sym.AbsAddr) + } + }, + }, + { + name: "byte at zero page", + text: "BYTE zpvar @ $80", + wantErr: false, + checkVar: func(t *testing.T, ctx *compiler.CompilerContext) { + sym := ctx.SymbolTable.Lookup("zpvar", nil) + if sym == nil { + t.Fatal("Expected variable zpvar to be declared") + } + if !sym.IsZeroPage() { + t.Error("Expected zero page variable") + } + }, + }, + { + name: "const byte", + text: "BYTE CONST maxval = 255", + wantErr: false, + checkVar: func(t *testing.T, ctx *compiler.CompilerContext) { + sym := ctx.SymbolTable.Lookup("maxval", nil) + if sym == nil { + t.Fatal("Expected constant maxval to be declared") + } + if !sym.IsConst() { + t.Error("Expected constant") + } + if sym.Value != 255 { + t.Errorf("Expected value 255, got %d", sym.Value) + } + }, + }, + { + name: "const byte with hex", + text: "BYTE CONST flag = $FF", + wantErr: false, + checkVar: func(t *testing.T, ctx *compiler.CompilerContext) { + sym := ctx.SymbolTable.Lookup("flag", nil) + if sym == nil { + t.Fatal("Expected constant flag to be declared") + } + if !sym.IsConst() { + t.Error("Expected constant") + } + if sym.Value != 255 { + t.Errorf("Expected value 255, got %d", sym.Value) + } + }, + }, + { + name: "byte with expression", + text: "BYTE x = 10+20", + wantErr: false, + checkVar: func(t *testing.T, ctx *compiler.CompilerContext) { + sym := ctx.SymbolTable.Lookup("x", nil) + if sym == nil { + t.Fatal("Expected variable x to be declared") + } + if sym.Value != 30 { + t.Errorf("Expected value 30, got %d", sym.Value) + } + }, + }, + { + name: "byte with binary", + text: "BYTE x = !11111111", + wantErr: false, + checkVar: func(t *testing.T, ctx *compiler.CompilerContext) { + sym := ctx.SymbolTable.Lookup("x", nil) + if sym == nil { + t.Fatal("Expected variable x to be declared") + } 
+ if sym.Value != 255 { + t.Errorf("Expected value 255, got %d", sym.Value) + } + }, + }, + { + name: "byte with bitwise OR", + text: "BYTE x = $F0|$0F", + wantErr: false, + checkVar: func(t *testing.T, ctx *compiler.CompilerContext) { + sym := ctx.SymbolTable.Lookup("x", nil) + if sym == nil { + t.Fatal("Expected variable x to be declared") + } + if sym.Value != 255 { + t.Errorf("Expected value 255, got %d", sym.Value) + } + }, + }, + { + name: "byte with bitwise AND", + text: "BYTE x = $FF&$0F", + wantErr: false, + checkVar: func(t *testing.T, ctx *compiler.CompilerContext) { + sym := ctx.SymbolTable.Lookup("x", nil) + if sym == nil { + t.Fatal("Expected variable x to be declared") + } + if sym.Value != 15 { + t.Errorf("Expected value 15, got %d", sym.Value) + } + }, + }, + { + name: "byte with out of range value", + text: "BYTE x = 256", + wantErr: true, + errContains: "out of range", + }, + { + name: "const byte out of range", + text: "BYTE CONST x = 256", + wantErr: true, + errContains: "out of range", + }, + { + name: "byte without name", + text: "BYTE", + wantErr: true, + errContains: "wrong number of parameters", + }, + { + name: "byte with invalid identifier", + text: "BYTE 123invalid", + wantErr: true, + errContains: "invalid identifier", + }, + { + name: "byte with wrong operator", + text: "BYTE x + 10", + wantErr: true, + errContains: "expected '=' or '@'", + }, + { + name: "const without equals", + text: "BYTE CONST x 10", + wantErr: true, + errContains: "expected '='", + }, + { + name: "wrong keyword instead of CONST", + text: "BYTE VAR x = 10", + wantErr: true, + errContains: "expected CONST keyword", + }, + { + name: "duplicate declaration", + text: "BYTE x", + wantErr: true, + errContains: "already declared", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + + // For duplicate test, pre-declare the variable + if tt.name == "duplicate 
declaration" { + ctx.SymbolTable.AddVar("x", "", compiler.KindByte, 0) + } + + cmd := &ByteCommand{} + line := preproc.Line{ + Text: tt.text, + Filename: "test.c65", + LineNo: 1, + Kind: preproc.Source, + PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(), + } + + err := cmd.Interpret(line, ctx) + + if tt.wantErr { + if err == nil { + t.Fatal("Expected error, got nil") + } + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("Error %q does not contain %q", err.Error(), tt.errContains) + } + } else { + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if tt.checkVar != nil { + tt.checkVar(t, ctx) + } + } + }) + } +} + +func TestByteCommand_Generate(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + + cmd := &ByteCommand{} + line := preproc.Line{ + Text: "BYTE x = 10", + Filename: "test.c65", + LineNo: 1, + Kind: preproc.Source, + PragmaSetIndex: 0, + } + + // Interpret first + if err := cmd.Interpret(line, ctx); err != nil { + t.Fatalf("Interpret failed: %v", err) + } + + // Generate should return nil (variables handled by assembleOutput) + output, err := cmd.Generate(ctx) + if err != nil { + t.Errorf("Generate returned error: %v", err) + } + if output != nil { + t.Errorf("Generate should return nil, got %v", output) + } +} + +func TestByteCommand_InFunctionScope(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + + // Simulate being inside a function + // We'd need to push function context, but for now we can test the scoping manually + scope := "myFunc" + ctx.SymbolTable.AddVar("localVar", scope, compiler.KindByte, 5) + + // Check it was added with correct scope + sym := ctx.SymbolTable.Lookup("localVar", []string{scope}) + if sym == nil { + t.Fatal("Expected local variable to be declared") + } + if sym.Scope != scope { + t.Errorf("Expected scope %q, got %q", scope, sym.Scope) + } + if sym.FullName() != "myFunc_localVar" { + 
t.Errorf("Expected full name myFunc_localVar, got %q", sym.FullName()) + } +} + +func TestByteCommand_WithConstantExpression(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + + // First, declare a constant + ctx.SymbolTable.AddConst("MAXVAL", "", compiler.KindByte, 200) + ctx.SymbolTable.AddConst("OFFSET", "", compiler.KindByte, 10) + + // Now declare a byte using the constant in an expression + cmd := &ByteCommand{} + line := preproc.Line{ + Text: "BYTE x = MAXVAL+OFFSET", + Filename: "test.c65", + LineNo: 1, + Kind: preproc.Source, + PragmaSetIndex: 0, + } + + if err := cmd.Interpret(line, ctx); err != nil { + t.Fatalf("Interpret failed: %v", err) + } + + // Check value + sym := ctx.SymbolTable.Lookup("x", nil) + if sym == nil { + t.Fatal("Expected variable x to be declared") + } + if sym.Value != 210 { + t.Errorf("Expected value 210 (200+10), got %d", sym.Value) + } +} + +func TestByteCommand_ConstantNotFound(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + + cmd := &ByteCommand{} + line := preproc.Line{ + Text: "BYTE x = UNKNOWN", + Filename: "test.c65", + LineNo: 1, + Kind: preproc.Source, + PragmaSetIndex: 0, + } + + err := cmd.Interpret(line, ctx) + if err == nil { + t.Fatal("Expected error for unknown constant") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("Error should mention constant not found, got: %v", err) + } +} diff --git a/internal/compiler/command.go b/internal/compiler/command.go new file mode 100644 index 0000000..3701646 --- /dev/null +++ b/internal/compiler/command.go @@ -0,0 +1,59 @@ +package compiler + +import ( + "fmt" + + "c65gm/internal/preproc" +) + +// Command represents a script command that can interpret source lines and generate assembly +type Command interface { + // WillHandle checks if this command can handle the given line + // Should clear any internal state and return true if it handles this line + WillHandle(line 
preproc.Line) bool + + // Interpret parses and validates the line, storing state in the command + // Returns error if line is malformed or invalid + Interpret(line preproc.Line, ctx *CompilerContext) error + + // Generate produces assembly output based on previously interpreted line + // Returns assembly lines and error if generation fails + Generate(ctx *CompilerContext) ([]string, error) +} + +// CommandRegistry manages registered commands and dispatches lines to appropriate handlers +type CommandRegistry struct { + commands []Command +} + +// NewCommandRegistry creates a new command registry +func NewCommandRegistry() *CommandRegistry { + return &CommandRegistry{ + commands: make([]Command, 0), + } +} + +// Register adds a command to the registry +func (r *CommandRegistry) Register(cmd Command) { + r.commands = append(r.commands, cmd) +} + +// FindHandler finds the first command that will handle the given line +// Returns the command and true if found, nil and false otherwise +func (r *CommandRegistry) FindHandler(line preproc.Line) (Command, bool) { + for _, cmd := range r.commands { + if cmd.WillHandle(line) { + return cmd, true + } + } + return nil, false +} + +// UnhandledLineError represents an error when no command handles a source line +type UnhandledLineError struct { + Line preproc.Line +} + +func (e *UnhandledLineError) Error() string { + return fmt.Sprintf("%s:%d: unhandled line: %s", e.Line.Filename, e.Line.LineNo, e.Line.Text) +} diff --git a/internal/compiler/compiler.go b/internal/compiler/compiler.go new file mode 100644 index 0000000..4ddccbe --- /dev/null +++ b/internal/compiler/compiler.go @@ -0,0 +1,116 @@ +package compiler + +import ( + "fmt" + "strings" + + "c65gm/internal/preproc" +) + +// Compiler orchestrates the compilation process +type Compiler struct { + ctx *CompilerContext + registry *CommandRegistry +} + +// NewCompiler creates a new compiler with initialized context and registry +func NewCompiler(pragma *preproc.Pragma) *Compiler 
{ + return &Compiler{ + ctx: NewCompilerContext(pragma), + registry: NewCommandRegistry(), + } +} + +// Context returns the compiler context (for registering commands that need it) +func (c *Compiler) Context() *CompilerContext { + return c.ctx +} + +// Registry returns the command registry (for registering commands) +func (c *Compiler) Registry() *CommandRegistry { + return c.registry +} + +// Compile processes preprocessed lines and generates assembly output +func (c *Compiler) Compile(lines []preproc.Line) ([]string, error) { + var codeOutput []string + + for _, line := range lines { + // Skip non-source lines (assembler and script handled differently) + if line.Kind != preproc.Source { + if line.Kind == preproc.Assembler { + // Pass through assembler lines verbatim + codeOutput = append(codeOutput, line.Text) + } + // Script lines ignored for now + continue + } + + // Skip empty/whitespace-only lines + if strings.TrimSpace(line.Text) == "" { + continue + } + + // Find handler for this line + cmd, found := c.registry.FindHandler(line) + if !found { + return nil, &UnhandledLineError{Line: line} + } + + // Interpret the line + if err := cmd.Interpret(line, c.ctx); err != nil { + return nil, fmt.Errorf("%s:%d: %w", line.Filename, line.LineNo, err) + } + + // Generate assembly + asmLines, err := cmd.Generate(c.ctx) + if err != nil { + return nil, fmt.Errorf("%s:%d: %w", line.Filename, line.LineNo, err) + } + + codeOutput = append(codeOutput, asmLines...) + } + + // Assemble final output with headers and footers + return c.assembleOutput(codeOutput), nil +} + +// assembleOutput combines all generated sections into final assembly +func (c *Compiler) assembleOutput(codeLines []string) []string { + var output []string + + // Header comment + output = append(output, ";Generated by c65gm") + output = append(output, "") + + // Constants section + if constLines := GenerateConstants(c.ctx.SymbolTable); len(constLines) > 0 { + output = append(output, constLines...) 
+ } + + // Absolute addresses section + if absLines := GenerateAbsolutes(c.ctx.SymbolTable); len(absLines) > 0 { + output = append(output, absLines...) + } + + // Main code section + output = append(output, ";Main code") + output = append(output, "") + output = append(output, codeLines...) + output = append(output, "") + + // Variables section + if varLines := GenerateVariables(c.ctx.SymbolTable); len(varLines) > 0 { + output = append(output, varLines...) + } + + // Constant strings section + if strLines := c.ctx.ConstStrHandler.GenerateConstStrDecls(); len(strLines) > 0 { + output = append(output, ";Constant strings (from c65gm)") + output = append(output, "") + output = append(output, strLines...) + output = append(output, "") + } + + return output +} diff --git a/internal/compiler/compiler_test.go b/internal/compiler/compiler_test.go new file mode 100644 index 0000000..21202c0 --- /dev/null +++ b/internal/compiler/compiler_test.go @@ -0,0 +1,148 @@ +package compiler + +import ( + "fmt" + "strings" + "testing" + + "c65gm/internal/preproc" + "c65gm/internal/utils" +) + +// TestBreakCommand is a simple command implementation for testing +type TestBreakCommand struct { + line preproc.Line +} + +func (c *TestBreakCommand) WillHandle(line preproc.Line) bool { + params, err := utils.ParseParams(line.Text) + if err != nil || len(params) == 0 { + return false + } + return strings.ToUpper(params[0]) == "BREAK" +} + +func (c *TestBreakCommand) Interpret(line preproc.Line, ctx *CompilerContext) error { + c.line = line + + params, err := utils.ParseParams(line.Text) + if err != nil { + return err + } + + if len(params) != 1 { + return fmt.Errorf("BREAK does not expect parameters") + } + + return nil +} + +func (c *TestBreakCommand) Generate(ctx *CompilerContext) ([]string, error) { + // BREAK jumps to end of WHILE loop + label, err := ctx.WhileStack.Peek() + if err != nil { + return nil, fmt.Errorf("BREAK outside of WHILE loop") + } + + return []string{ + fmt.Sprintf(" jmp 
%s_end", label), + }, nil +} + +func TestCompilerArchitecture(t *testing.T) { + // Create pragma + pragma := preproc.NewPragma() + + // Create compiler + comp := NewCompiler(pragma) + + // Register BREAK command + comp.Registry().Register(&TestBreakCommand{}) + + // Create test input - BREAK inside a simulated WHILE + lines := []preproc.Line{ + { + Text: "BREAK", + Filename: "test.c65", + LineNo: 1, + Kind: preproc.Source, + PragmaSetIndex: 0, + }, + } + + // Manually push a WHILE label so BREAK has something to reference + comp.Context().WhileStack.Push() + + // Compile + output, err := comp.Compile(lines) + + // Should fail because BREAK needs proper WHILE context + // But this tests the basic flow: WillHandle -> Interpret -> Generate + if err != nil { + t.Logf("Expected controlled error: %v", err) + } + + // Check we got some output structure + if len(output) == 0 { + t.Logf("Got output lines: %d", len(output)) + } + + t.Logf("Output:\n%s", strings.Join(output, "\n")) +} + +func TestCommandRegistry(t *testing.T) { + registry := NewCommandRegistry() + + breakCmd := &TestBreakCommand{} + registry.Register(breakCmd) + + line := preproc.Line{ + Text: "BREAK", + Filename: "test.c65", + LineNo: 1, + Kind: preproc.Source, + } + + cmd, found := registry.FindHandler(line) + if !found { + t.Fatal("Expected to find BREAK handler") + } + + if cmd != breakCmd { + t.Fatal("Expected to get same command instance") + } +} + +func TestCompilerContext(t *testing.T) { + pragma := preproc.NewPragma() + ctx := NewCompilerContext(pragma) + + // Test that all resources are initialized + if ctx.SymbolTable == nil { + t.Error("SymbolTable not initialized") + } + if ctx.FunctionHandler == nil { + t.Error("FunctionHandler not initialized") + } + if ctx.ConstStrHandler == nil { + t.Error("ConstStrHandler not initialized") + } + if ctx.WhileStack == nil { + t.Error("WhileStack not initialized") + } + if ctx.IfStack == nil { + t.Error("IfStack not initialized") + } + if ctx.GeneralStack == 
nil { + t.Error("GeneralStack not initialized") + } + if ctx.Pragma == nil { + t.Error("Pragma not initialized") + } + + // Test CurrentScope + scope := ctx.CurrentScope() + if scope != nil { + t.Errorf("Expected nil scope in global context, got %v", scope) + } +} diff --git a/internal/compiler/context.go b/internal/compiler/context.go new file mode 100644 index 0000000..5e7a458 --- /dev/null +++ b/internal/compiler/context.go @@ -0,0 +1,55 @@ +package compiler + +import ( + "c65gm/internal/preproc" +) + +// CompilerContext holds all shared resources needed by commands during compilation +type CompilerContext struct { + // Symbol table for variables and constants + SymbolTable *SymbolTable + + // Function handler for FUNC/CALL/FEND commands + FunctionHandler *FunctionHandler + + // Constant string handler + ConstStrHandler *ConstantStringHandler + + // Label stacks for control flow + WhileStack *LabelStack // WHILE...WEND + IfStack *LabelStack // IF...ENDIF + GeneralStack *LabelStack // General purpose (GOSUB, etc) + + // Pragma access for per-line pragma lookup + Pragma *preproc.Pragma +} + +// NewCompilerContext creates a new compiler context with initialized resources +func NewCompilerContext(pragma *preproc.Pragma) *CompilerContext { + symTable := NewSymbolTable() + constStrHandler := NewConstantStringHandler() + generalStack := NewLabelStack("_L") + + ctx := &CompilerContext{ + SymbolTable: symTable, + ConstStrHandler: constStrHandler, + WhileStack: NewLabelStack("_W"), + IfStack: NewLabelStack("_I"), + GeneralStack: generalStack, + Pragma: pragma, + } + + // FunctionHandler needs references to other components + ctx.FunctionHandler = NewFunctionHandler(symTable, generalStack, constStrHandler, pragma) + + return ctx +} + +// CurrentScope returns the current function scope(s) for symbol resolution +func (ctx *CompilerContext) CurrentScope() []string { + funcName := ctx.FunctionHandler.CurrentFunction() + if funcName == "" { + return nil + } + return 
[]string{funcName} +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go new file mode 100644 index 0000000..8788bb1 --- /dev/null +++ b/internal/utils/utils.go @@ -0,0 +1,263 @@ +package utils + +import ( + "fmt" + "strconv" + "strings" +) + +// ParseParams splits a line into space-separated parameters, respecting quoted strings +func ParseParams(s string) ([]string, error) { + s = strings.TrimSpace(s) + if s == "" { + return []string{}, nil + } + + var params []string + var current strings.Builder + inString := false + + for i := 0; i < len(s); i++ { + ch := s[i] + + if ch == '"' { + inString = !inString + current.WriteByte(ch) + continue + } + + if !inString && (ch == ' ' || ch == '\t') { + if current.Len() > 0 { + params = append(params, current.String()) + current.Reset() + } + } else { + current.WriteByte(ch) + } + } + + if current.Len() > 0 { + params = append(params, current.String()) + } + + if inString { + return nil, fmt.Errorf("unterminated string") + } + + return params, nil +} + +// NormalizeSpaces reduces multiple spaces to single space, respecting quoted strings +func NormalizeSpaces(s string) string { + s = strings.TrimSpace(s) + var result strings.Builder + inString := false + lastWasSpace := false + + for i := 0; i < len(s); i++ { + ch := s[i] + + if ch == '"' { + inString = !inString + result.WriteByte(ch) + lastWasSpace = false + continue + } + + if !inString { + if ch == ' ' || ch == '\t' { + if !lastWasSpace { + result.WriteByte(' ') + lastWasSpace = true + } + } else { + result.WriteByte(ch) + lastWasSpace = false + } + } else { + result.WriteByte(ch) + lastWasSpace = false + } + } + + return result.String() +} + +// IsStringLiteral checks if s is a quoted string +func IsStringLiteral(s string) bool { + l := len(s) + return l >= 2 && s[0] == '"' && s[l-1] == '"' +} + +// StripQuotes removes surrounding quotes from a string literal +func StripQuotes(s string) string { + if IsStringLiteral(s) { + return s[1 : len(s)-1] + } + return 
s +} + +// ToUpper converts string to uppercase +func ToUpper(s string) string { + return strings.ToUpper(s) +} + +// ValidateIdentifier checks if s is a valid identifier (starts with letter/underscore, continues with alphanumeric/underscore) +func ValidateIdentifier(s string) bool { + if len(s) == 0 { + return false + } + + first := s[0] + if !((first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z') || first == '_') { + return false + } + + for i := 1; i < len(s); i++ { + ch := s[i] + if !((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_') { + return false + } + } + + return true +} + +// ConstantLookup is a function type for looking up constant values by name +type ConstantLookup func(name string) (value int64, found bool) + +// EvaluateExpression evaluates a simple left-to-right expression +// Supports: +// - Decimal: 123 +// - Hex: $FF +// - Binary: !11111111 +// - Constants: MAXVAL (via lookup function) +// - Operators: +, -, *, /, | (binary OR), & (binary AND) +// +// Returns value and error. If lookup is nil, constants are not supported. 
+func EvaluateExpression(expr string, lookup ConstantLookup) (int64, error) { + expr = strings.TrimSpace(expr) + if expr == "" { + return 0, fmt.Errorf("empty expression") + } + + // Split expression into terms and operators + var terms []string + var operators []rune + var currentTerm strings.Builder + + for i := 0; i < len(expr); i++ { + ch := expr[i] + + if isOperator(ch) { + // Save current term + if currentTerm.Len() > 0 { + terms = append(terms, strings.TrimSpace(currentTerm.String())) + currentTerm.Reset() + } + operators = append(operators, rune(ch)) + } else { + currentTerm.WriteByte(ch) + } + } + + // Save final term + if currentTerm.Len() > 0 { + terms = append(terms, strings.TrimSpace(currentTerm.String())) + } + + if len(terms) == 0 { + return 0, fmt.Errorf("no terms in expression") + } + + if len(operators) != len(terms)-1 { + return 0, fmt.Errorf("mismatched operators and terms") + } + + // Evaluate first term + result, err := evaluateTerm(terms[0], lookup) + if err != nil { + return 0, fmt.Errorf("term %q: %w", terms[0], err) + } + + // Apply operators left-to-right + for i, op := range operators { + nextVal, err := evaluateTerm(terms[i+1], lookup) + if err != nil { + return 0, fmt.Errorf("term %q: %w", terms[i+1], err) + } + + switch op { + case '+': + result = result + nextVal + case '-': + result = result - nextVal + case '*': + result = result * nextVal + case '/': + if nextVal == 0 { + return 0, fmt.Errorf("division by zero") + } + result = result / nextVal + case '|': + result = result | nextVal + case '&': + result = result & nextVal + default: + return 0, fmt.Errorf("unknown operator %q", op) + } + } + + return result, nil +} + +// isOperator checks if a character is an operator +func isOperator(ch byte) bool { + return ch == '+' || ch == '-' || ch == '*' || ch == '/' || ch == '|' || ch == '&' +} + +// evaluateTerm evaluates a single term (number or constant) +func evaluateTerm(term string, lookup ConstantLookup) (int64, error) { + term = 
strings.TrimSpace(term) + if term == "" { + return 0, fmt.Errorf("empty term") + } + + // Check for hex: $FF + if strings.HasPrefix(term, "$") { + val, err := strconv.ParseInt(term[1:], 16, 64) + if err != nil { + return 0, fmt.Errorf("invalid hex number: %w", err) + } + return val, nil + } + + // Check for binary: !11111111 + if strings.HasPrefix(term, "!") { + val, err := strconv.ParseInt(term[1:], 2, 64) + if err != nil { + return 0, fmt.Errorf("invalid binary number: %w", err) + } + return val, nil + } + + // Check for identifier (constant) + first := term[0] + if (first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z') || first == '_' { + if lookup == nil { + return 0, fmt.Errorf("constant %q not supported (no lookup function)", term) + } + val, found := lookup(term) + if !found { + return 0, fmt.Errorf("constant %q not found", term) + } + return val, nil + } + + // Decimal number + val, err := strconv.ParseInt(term, 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid decimal number: %w", err) + } + return val, nil +}
diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go
new file mode 100644
index 0000000..802d4ab
--- /dev/null
+++ b/internal/utils/utils_test.go
@@ -0,0 +1,297 @@
+package utils
+
+import (
+	"reflect"
+	"testing"
+)
+
+// TestParseParams verifies the parameter lists ParseParams is expected to
+// return: space-separated words become separate parameters, a double-quoted
+// string stays one parameter with its quotes, and an unterminated string is
+// an error.
+func TestParseParams(t *testing.T) {
+	tests := []struct {
+		name    string
+		input   string
+		want    []string
+		wantErr bool
+	}{
+		{
+			name:  "simple params",
+			input: "BYTE x 10",
+			want:  []string{"BYTE", "x", "10"},
+		},
+		{
+			name:  "with quoted string",
+			input: `PRINT "hello world"`,
+			want:  []string{"PRINT", `"hello world"`},
+		},
+		{
+			// Runs of spaces between parameters must act as one separator.
+			name:  "multiple spaces",
+			input: "WORD   y    20",
+			want:  []string{"WORD", "y", "20"},
+		},
+		{
+			name:  "empty string",
+			input: "",
+			want:  []string{},
+		},
+		{
+			name:  "only spaces",
+			input: " ",
+			want:  []string{},
+		},
+		{
+			name:    "unterminated string",
+			input:   `PRINT "hello`,
+			wantErr: true,
+		},
+		{
+			name:  "string with spaces inside",
+			input: `CALL print ( "hello world" )`,
+			want:  []string{"CALL", "print", "(", `"hello world"`, ")"},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := ParseParams(tt.input)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("ParseParams() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			// reflect.DeepEqual treats a nil slice and an empty slice as
+			// different, so the empty-input cases require a non-nil result.
+			if !tt.wantErr && !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("ParseParams() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+// TestNormalizeSpaces verifies that NormalizeSpaces is expected to collapse
+// runs of spaces and tabs into one space, trim both ends, and leave the
+// inside of a double-quoted string untouched.
+func TestNormalizeSpaces(t *testing.T) {
+	tests := []struct {
+		name  string
+		input string
+		want  string
+	}{
+		{
+			name:  "multiple spaces",
+			input: "BYTE   x    10",
+			want:  "BYTE x 10",
+		},
+		{
+			// Spaces outside the quotes collapse; those inside must survive.
+			name:  "preserves string spaces",
+			input: `PRINT   "hello   world"`,
+			want:  `PRINT "hello   world"`,
+		},
+		{
+			name:  "leading trailing spaces",
+			input: " BYTE x ",
+			want:  "BYTE x",
+		},
+		{
+			name:  "tabs",
+			input: "BYTE\t\tx",
+			want:  "BYTE x",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := NormalizeSpaces(tt.input)
+			if got != tt.want {
+				t.Errorf("NormalizeSpaces() = %q, want %q", got, tt.want)
+			}
+		})
+	}
+}
+
+// TestIsStringLiteral checks IsStringLiteral against double-quoted text
+// (accepted) and unquoted, single-quoted, half-quoted and lone-quote text
+// (rejected).
+func TestIsStringLiteral(t *testing.T) {
+	tests := []struct {
+		input string
+		want  bool
+	}{
+		{`"hello"`, true},
+		{`"hello world"`, true},
+		{`""`, true},
+		{`"`, false},
+		{`hello`, false},
+		{`'hello'`, false},
+		{`hello"`, false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.input, func(t *testing.T) {
+			got := IsStringLiteral(tt.input)
+			if got != tt.want {
+				t.Errorf("IsStringLiteral(%q) = %v, want %v", tt.input, got, tt.want)
+			}
+		})
+	}
+}
+
+// TestStripQuotes checks that StripQuotes is expected to remove the
+// surrounding double quotes from a quoted string and to return any other
+// input, including a lone quote, unchanged.
+func TestStripQuotes(t *testing.T) {
+	tests := []struct {
+		input string
+		want  string
+	}{
+		{`"hello"`, "hello"},
+		{`"hello world"`, "hello world"},
+		{`""`, ""},
+		{`hello`, "hello"},
+		{`"`, `"`},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.input, func(t *testing.T) {
+			got := StripQuotes(tt.input)
+			if got != tt.want {
+				t.Errorf("StripQuotes(%q) = %q, want %q", tt.input, got, tt.want)
+			}
+		})
+	}
+}
+
+// TestValidateIdentifier checks which names ValidateIdentifier is expected to
+// accept: letters, digits and underscores not starting with a digit pass; the
+// empty string and names containing '-', ' ' or '.' are rejected.
+func TestValidateIdentifier(t *testing.T) {
+	tests := []struct {
+		input string
+		want  bool
+	}{
+		{"x", true},
+		{"_temp", true},
+		{"myVar123", true},
+		{"_", true},
+		{"var_name", true},
+		{"123abc", false},
+		{"", false},
+		{"my-var", false},
+		{"my var", false},
+		{"my.var", false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.input, func(t *testing.T) {
+			got := ValidateIdentifier(tt.input)
+			if got != tt.want {
+				t.Errorf("ValidateIdentifier(%q) = %v, want %v", tt.input, got, tt.want)
+			}
+		})
+	}
+}
+
+// TestEvaluateExpression exercises EvaluateExpression with decimal, hex ($)
+// and binary (!) literals, the + - * / | & operators, named constants
+// resolved through a lookup function, and malformed input. Expected values
+// assume operators apply strictly left to right with no precedence (see the
+// "mixed math" case).
+func TestEvaluateExpression(t *testing.T) {
+	// Simple constant lookup for testing
+	constants := map[string]int64{
+		"MAXVAL": 255,
+		"BASE":   100,
+		"OFFSET": 10,
+	}
+	constLookup := func(name string) (int64, bool) {
+		val, ok := constants[name]
+		return val, ok
+	}
+
+	tests := []struct {
+		name    string
+		expr    string
+		lookup  ConstantLookup
+		want    int64
+		wantErr bool
+	}{
+		// Decimal
+		{name: "simple decimal", expr: "42", want: 42},
+		{name: "decimal zero", expr: "0", want: 0},
+
+		// Hex
+		{name: "hex lowercase", expr: "$ff", want: 255},
+		{name: "hex uppercase", expr: "$FF", want: 255},
+		{name: "hex zero", expr: "$00", want: 0},
+		{name: "hex word", expr: "$C000", want: 0xC000},
+
+		// Binary
+		{name: "binary 8bit", expr: "!11111111", want: 255},
+		{name: "binary zero", expr: "!00000000", want: 0},
+		{name: "binary mixed", expr: "!10101010", want: 0xAA},
+
+		// Addition
+		{name: "add decimal", expr: "10+20", want: 30},
+		{name: "add hex", expr: "$10+$20", want: 0x30},
+		{name: "add binary", expr: "!1010+!0101", want: 15},
+		{name: "add multiple", expr: "10+20+30", want: 60},
+
+		// Subtraction
+		{name: "subtract", expr: "100-25", want: 75},
+		{name: "subtract chain", expr: "100-20-10", want: 70},
+
+		// Multiplication
+		{name: "multiply", expr: "10*5", want: 50},
+		{name: "multiply chain", expr: "2*3*4", want: 24},
+
+		// Division
+		{name: "divide", expr: "100/5", want: 20},
+		{name: "divide chain", expr: "100/5/2", want: 10},
+
+		// Binary OR
+		{name: "binary or", expr: "$F0|$0F", want: 0xFF},
+		{name: "binary or multiple", expr: "1|2|4", want: 7},
+
+		// Binary AND
+		{name: "binary and", expr: "$FF&$0F", want: 0x0F},
+		{name: "binary and multiple", expr: "255&127&63", want: 63},
+
+		// Mixed operators
+		{name: "mixed math", expr: "10+20*2", want: 60}, // left-to-right: (10+20)*2
+		{name: "mixed bitwise", expr: "255&$F0|$0F", want: 0xFF},
+
+		// Constants
+		{name: "constant", expr: "MAXVAL", lookup: constLookup, want: 255},
+		{name: "constant add", expr: "BASE+OFFSET", lookup: constLookup, want: 110},
+		{name: "constant unknown", expr: "UNKNOWN", lookup: constLookup, wantErr: true},
+		{name: "constant no lookup", expr: "MAXVAL", wantErr: true},
+
+		// Spaces
+		{name: "spaces around ops", expr: "10 + 20", want: 30},
+		{name: "spaces everywhere", expr: " 100 - 20 ", want: 80},
+
+		// Errors
+		{name: "empty", expr: "", wantErr: true},
+		{name: "invalid hex", expr: "$ZZ", wantErr: true},
+		{name: "invalid binary", expr: "!22", wantErr: true},
+		{name: "division by zero", expr: "10/0", wantErr: true},
+		{name: "trailing operator", expr: "10+", wantErr: true},
+		{name: "leading operator", expr: "+10", wantErr: true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := EvaluateExpression(tt.expr, tt.lookup)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("EvaluateExpression() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if !tt.wantErr && got != tt.want {
+				t.Errorf("EvaluateExpression(%q) = %v, want %v", tt.expr, got, tt.want)
+			}
+		})
+	}
+}