diff --git a/internal/compiler/labelstack.go b/internal/compiler/labelstack.go new file mode 100644 index 0000000..7873157 --- /dev/null +++ b/internal/compiler/labelstack.go @@ -0,0 +1,48 @@ +package compiler + +import "fmt" + +type LabelStack struct { + stack []string + counter int + prefix string +} + +func NewLabelStack(prefix string) *LabelStack { + return &LabelStack{ + stack: make([]string, 0), + counter: 0, + prefix: prefix, + } +} + +func (ls *LabelStack) Push() string { + ls.counter++ + label := fmt.Sprintf("%s%d", ls.prefix, ls.counter) + ls.stack = append(ls.stack, label) + return label +} + +func (ls *LabelStack) Peek() (string, error) { + if len(ls.stack) == 0 { + return "", fmt.Errorf("stack underflow: %s stack is empty", ls.prefix) + } + return ls.stack[len(ls.stack)-1], nil +} + +func (ls *LabelStack) Pop() (string, error) { + if len(ls.stack) == 0 { + return "", fmt.Errorf("stack underflow: %s stack is empty", ls.prefix) + } + label := ls.stack[len(ls.stack)-1] + ls.stack = ls.stack[:len(ls.stack)-1] + return label, nil +} + +func (ls *LabelStack) IsEmpty() bool { + return len(ls.stack) == 0 +} + +func (ls *LabelStack) Size() int { + return len(ls.stack) +} diff --git a/internal/compiler/labelstack_test.go b/internal/compiler/labelstack_test.go new file mode 100644 index 0000000..d34bf98 --- /dev/null +++ b/internal/compiler/labelstack_test.go @@ -0,0 +1,232 @@ +package compiler + +import "testing" + +func TestLabelStack_PushCreatesUniqueLabels(t *testing.T) { + stack := NewLabelStack("test") + + lbl1 := stack.Push() + lbl2 := stack.Push() + lbl3 := stack.Push() + + if lbl1 != "test1" { + t.Errorf("expected test1, got %s", lbl1) + } + if lbl2 != "test2" { + t.Errorf("expected test2, got %s", lbl2) + } + if lbl3 != "test3" { + t.Errorf("expected test3, got %s", lbl3) + } +} + +func TestLabelStack_PeekDoesNotRemove(t *testing.T) { + stack := NewLabelStack("peek") + + stack.Push() + lbl := stack.Push() + + peeked, err := stack.Peek() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if peeked != lbl { + t.Errorf("expected %s, got %s", lbl, peeked) + } + + if stack.Size() != 2 { + t.Errorf("expected size 2 after peek, got %d", stack.Size()) + } +} + +func TestLabelStack_PopRemoves(t *testing.T) { + stack := NewLabelStack("pop") + + stack.Push() + lbl := stack.Push() + + popped, err := stack.Pop() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if popped != lbl { + t.Errorf("expected %s, got %s", lbl, popped) + } + + if stack.Size() != 1 { + t.Errorf("expected size 1 after pop, got %d", stack.Size()) + } +} + +func TestLabelStack_PeekEmptyReturnsError(t *testing.T) { + stack := NewLabelStack("empty") + + _, err := stack.Peek() + if err == nil { + t.Error("expected error on peek of empty stack") + } +} + +func TestLabelStack_PopEmptyReturnsError(t *testing.T) { + stack := NewLabelStack("empty") + + _, err := stack.Pop() + if err == nil { + t.Error("expected error on pop of empty stack") + } +} + +func TestLabelStack_IsEmpty(t *testing.T) { + stack := NewLabelStack("check") + + if !stack.IsEmpty() { + t.Error("new stack should be empty") + } + + stack.Push() + if stack.IsEmpty() { + t.Error("stack with item should not be empty") + } + + stack.Pop() + if !stack.IsEmpty() { + t.Error("stack after pop should be empty") + } +} + +func TestLabelStack_CounterNeverResets(t *testing.T) { + stack := NewLabelStack("counter") + + lbl1 := stack.Push() + lbl2 := stack.Push() + stack.Pop() + stack.Pop() + lbl3 := stack.Push() + + if lbl1 != "counter1" || lbl2 != "counter2" || lbl3 != "counter3" { + t.Errorf("counter reset detected: %s, %s, %s", lbl1, lbl2, lbl3) + } +} + +func TestLabelStack_WhileWendPattern(t *testing.T) { + whileStack := NewLabelStack("whilelbl") + wendStack := NewLabelStack("wendlbl") + + // WHILE + whileLbl := whileStack.Push() + wendLbl := wendStack.Push() + + if whileLbl != "whilelbl1" { + t.Errorf("expected whilelbl1, got %s", whileLbl) + } + if wendLbl != "wendlbl1" { + t.Errorf("expected wendlbl1, got %s", wendLbl) + } + + // WEND + w, err := whileStack.Pop() + if err != nil || w != whileLbl { + t.Errorf("failed to pop while label") + } + + wd, err := wendStack.Pop() + if err != nil || wd != wendLbl { + t.Errorf("failed to pop wend label") + } +} + +func TestLabelStack_IfElseEndifPattern(t *testing.T) { + ifStack := NewLabelStack("iflbl") + + // IF + ifLbl1 := ifStack.Push() + + // ELSE - peek the if label, then push new one + peeked, err := ifStack.Peek() + if err != nil || peeked != ifLbl1 { + t.Errorf("failed to peek if label") + } + + // Commit (pop) the first label + popped, err := ifStack.Pop() + if err != nil || popped != ifLbl1 { + t.Errorf("failed to pop first if label") + } + + // Push new label for ENDIF + ifLbl2 := ifStack.Push() + + // ENDIF + endifLbl, err := ifStack.Pop() + if err != nil || endifLbl != ifLbl2 { + t.Errorf("failed to pop endif label") + } + + if ifLbl1 != "iflbl1" || ifLbl2 != "iflbl2" { + t.Errorf("unexpected label names: %s, %s", ifLbl1, ifLbl2) + } +} + +func TestLabelStack_NestedIfPattern(t *testing.T) { + ifStack := NewLabelStack("iflbl") + + // Outer IF + outerIf := ifStack.Push() + + // Inner IF + innerIf := ifStack.Push() + + // Inner ENDIF + innerEnd, _ := ifStack.Pop() + if innerEnd != innerIf { + t.Errorf("inner endif mismatch") + } + + // Outer ENDIF + outerEnd, _ := ifStack.Pop() + if outerEnd != outerIf { + t.Errorf("outer endif mismatch") + } + + if outerIf != "iflbl1" || innerIf != "iflbl2" { + t.Errorf("nested labels incorrect: %s, %s", outerIf, innerIf) + } +} + +func TestLabelStack_BreakPattern(t *testing.T) { + wendStack := NewLabelStack("wendlbl") + + // WHILE + wendLbl := wendStack.Push() + + // BREAK - needs to peek wend label without popping + breakTarget, err := wendStack.Peek() + if err != nil || breakTarget != wendLbl { + t.Errorf("failed to peek for break") + } + + // Stack should still have the label + if wendStack.Size() != 1 { + t.Error("peek should not modify stack") + } + + // WEND - now pop it + wend, _ := wendStack.Pop() + if wend != wendLbl { + t.Errorf("wend label mismatch") + } +} + +func TestLabelStack_MultipleSeparateStacks(t *testing.T) { + whileStack := NewLabelStack("whilelbl") + ifStack := NewLabelStack("iflbl") + generalStack := NewLabelStack("general") + + w := whileStack.Push() + i := ifStack.Push() + g := generalStack.Push() + + if w != "whilelbl1" || i != "iflbl1" || g != "general1" { + t.Errorf("separate stacks interfering: %s, %s, %s", w, i, g) + } +} diff --git a/internal/compiler/symboltable.go b/internal/compiler/symboltable.go new file mode 100644 index 0000000..d32877e --- /dev/null +++ b/internal/compiler/symboltable.go @@ -0,0 +1,386 @@ +package compiler + +import ( + "fmt" + "strings" +) + +// VarKind represents the data type/size of a variable +type VarKind uint8 + +const ( + KindByte VarKind = iota + KindWord + // Future: KindDWord, Kind24bit, etc +) + +// SymbolFlags represents properties of a symbol as a bitfield +type SymbolFlags uint16 + +const ( + FlagByte SymbolFlags = 1 << iota + FlagWord + FlagConst + FlagAbsolute + FlagZeroPage + FlagLabelRef +) + +// Symbol represents a variable or constant declaration +type Symbol struct { + Name string + Scope string // empty string = global, otherwise function name + Flags SymbolFlags + Value uint16 // init value or const value + AbsAddr uint16 // if FlagAbsolute set + LabelRef string // if FlagLabelRef set +} + +// Helper methods for Symbol +func (s *Symbol) Has(flag SymbolFlags) bool { + return s.Flags&flag != 0 +} + +func (s *Symbol) HasAll(flags SymbolFlags) bool { + return s.Flags&flags == flags +} + +func (s *Symbol) IsByte() bool { return s.Has(FlagByte) } +func (s *Symbol) IsWord() bool { return s.Has(FlagWord) } +func (s *Symbol) IsConst() bool { return s.Has(FlagConst) } +func (s *Symbol) IsAbsolute() bool { return s.Has(FlagAbsolute) } +func (s *Symbol) IsZeroPage() bool { return s.Has(FlagZeroPage) } +func (s *Symbol) IsZeroPagePointer() bool { return s.HasAll(FlagAbsolute | FlagZeroPage) } + +// FullName returns the fully qualified name (scope.name or just name) +func (s *Symbol) FullName() string { + if s.Scope == "" { + return s.Name + } + return s.Scope + "." + s.Name +} + +// SymbolTable manages variable and constant declarations +type SymbolTable struct { + symbols []*Symbol // insertion order + byFullName map[string]*Symbol // fullname -> symbol + byScope map[string]map[string]*Symbol // scope -> name -> symbol +} + +// NewSymbolTable creates a new symbol table +func NewSymbolTable() *SymbolTable { + return &SymbolTable{ + symbols: make([]*Symbol, 0), + byFullName: make(map[string]*Symbol), + byScope: make(map[string]map[string]*Symbol), + } +} + +// AddVar adds a regular variable (byte or word) +func (st *SymbolTable) AddVar(name, scope string, kind VarKind, initValue uint16) error { + var flags SymbolFlags + + switch kind { + case KindByte: + flags = FlagByte + if initValue > 255 { + return fmt.Errorf("byte variable %q init value %d out of range", name, initValue) + } + case KindWord: + flags = FlagWord + default: + return fmt.Errorf("unknown variable kind: %d", kind) + } + + return st.add(&Symbol{ + Name: name, + Scope: scope, + Flags: flags, + Value: initValue, + }) +} + +// AddConst adds a constant (byte or word) +func (st *SymbolTable) AddConst(name, scope string, kind VarKind, value uint16) error { + var flags SymbolFlags + + switch kind { + case KindByte: + flags = FlagByte | FlagConst + if value > 255 { + return fmt.Errorf("byte constant %q value %d out of range", name, value) + } + case KindWord: + flags = FlagWord | FlagConst + default: + return fmt.Errorf("unknown variable kind: %d", kind) + } + + return st.add(&Symbol{ + Name: name, + Scope: scope, + Flags: flags, + Value: value, + }) +} + +// AddAbsolute adds a variable at a fixed memory address +func (st *SymbolTable) AddAbsolute(name, scope string, kind VarKind, addr uint16) error { + if addr > 0xFFFF { + return fmt.Errorf("absolute address %d exceeds 16-bit range", addr) + } + + var flags SymbolFlags + + switch kind { + case KindByte: + flags = FlagByte | FlagAbsolute + // Zero page check for bytes + if addr < 0x100 { + flags |= FlagZeroPage + } + case KindWord: + flags = FlagWord | FlagAbsolute + // Zero page check for words (pointer must fit in ZP) + if addr < 0xFF { + flags |= FlagZeroPage + } + default: + return fmt.Errorf("unknown variable kind: %d", kind) + } + + return st.add(&Symbol{ + Name: name, + Scope: scope, + Flags: flags, + AbsAddr: addr, + }) +} + +// AddLabel adds a word variable that references a label +func (st *SymbolTable) AddLabel(name, scope string, labelRef string) error { + return st.add(&Symbol{ + Name: name, + Scope: scope, + Flags: FlagWord | FlagLabelRef, + LabelRef: labelRef, + }) +} + +// add is the internal method that actually adds a symbol +func (st *SymbolTable) add(sym *Symbol) error { + fullName := sym.FullName() + + // Check for redeclaration + if _, exists := st.byFullName[fullName]; exists { + return fmt.Errorf("symbol %q already declared", fullName) + } + + // Add to all indexes + st.symbols = append(st.symbols, sym) + st.byFullName[fullName] = sym + + // Add to scope index + if st.byScope[sym.Scope] == nil { + st.byScope[sym.Scope] = make(map[string]*Symbol) + } + st.byScope[sym.Scope][sym.Name] = sym + + return nil +} + +// Lookup finds a symbol by name, resolving scope +// Searches local scope first (if currentScopes provided), then global +func (st *SymbolTable) Lookup(name string, currentScopes []string) *Symbol { + // Try local scopes first (innermost to outermost) + for i := len(currentScopes) - 1; i >= 0; i-- { + scope := currentScopes[i] + if scopeMap, ok := st.byScope[scope]; ok { + if sym, ok := scopeMap[name]; ok { + return sym + } + } + } + + // Try global scope + if scopeMap, ok := st.byScope[""]; ok { + if sym, ok := scopeMap[name]; ok { + return sym + } + } + + return nil +} + +// Get retrieves a symbol by its full name +func (st *SymbolTable) Get(fullName string) *Symbol { + return st.byFullName[fullName] +} + +// Symbols returns all symbols in insertion order +func (st *SymbolTable) Symbols() []*Symbol { + return st.symbols +} + +// Count returns the number of symbols +func (st *SymbolTable) Count() int { + return len(st.symbols) +} + +// ExpandName resolves a local name to its full name using scope resolution +func (st *SymbolTable) ExpandName(name string, currentScopes []string) string { + sym := st.Lookup(name, currentScopes) + if sym != nil { + return sym.FullName() + } + return name +} + +// String representation for debugging +func (s *Symbol) String() string { + var parts []string + parts = append(parts, fmt.Sprintf("Name=%s", s.FullName())) + + if s.IsByte() { + parts = append(parts, "BYTE") + } else if s.IsWord() { + parts = append(parts, "WORD") + } + + if s.IsConst() { + parts = append(parts, fmt.Sprintf("CONST=%d", s.Value)) + } else if s.IsAbsolute() { + parts = append(parts, fmt.Sprintf("@$%04X", s.AbsAddr)) + if s.IsZeroPage() { + parts = append(parts, "ZP") + } + } else if s.Has(FlagLabelRef) { + parts = append(parts, fmt.Sprintf("->%s", s.LabelRef)) + } else if s.Value != 0 { + parts = append(parts, fmt.Sprintf("=%d", s.Value)) + } + + return strings.Join(parts, " ") +} + +// Code generation functions for ACME assembler syntax + +// GenerateConstants generates constant definitions (name = $value) +func GenerateConstants(st *SymbolTable) []string { + var lines []string + hasConsts := false + + for _, sym := range st.Symbols() { + if !sym.IsConst() { + continue + } + + hasConsts = true + var line string + + if sym.IsByte() { + // Byte constant with decimal comment + line = fmt.Sprintf("%s = $%02x\t; %d", sym.FullName(), sym.Value, sym.Value) + } else { + // Word constant + line = fmt.Sprintf("%s = $%04x", sym.FullName(), sym.Value) + } + + lines = append(lines, line) + } + + if hasConsts { + // Prepend header + result := []string{ + ";Constant values (from c65gm)", + "", + } + result = append(result, lines...) + result = append(result, "") // blank line after + return result + } + + return nil +} + +// GenerateAbsolutes generates absolute address assignments (name = $addr) +func GenerateAbsolutes(st *SymbolTable) []string { + var lines []string + hasAbsolutes := false + + for _, sym := range st.Symbols() { + if !sym.IsAbsolute() { + continue + } + + hasAbsolutes = true + var line string + + if sym.IsZeroPage() { + // Zero-page: 2 hex digits + line = fmt.Sprintf("%s = $%02x", sym.FullName(), sym.AbsAddr) + } else { + // Non-zero-page: 4 hex digits + line = fmt.Sprintf("%s = $%04x", sym.FullName(), sym.AbsAddr) + } + + lines = append(lines, line) + } + + if hasAbsolutes { + // Prepend header + result := []string{ + ";Absolute variable definitions (from c65gm)", + "", + } + result = append(result, lines...) + result = append(result, "") // blank line after + return result + } + + return nil +} + +// GenerateVariables generates variable declarations (name !8 $value) +func GenerateVariables(st *SymbolTable) []string { + var lines []string + hasVars := false + + for _, sym := range st.Symbols() { + // Skip constants and absolutes - they're handled separately + if sym.IsConst() || sym.IsAbsolute() { + continue + } + + hasVars = true + var line string + + if sym.IsByte() { + // Byte variable with decimal comment + line = fmt.Sprintf("%s\t!8 $%02x\t; %d", sym.FullName(), sym.Value&0xFF, sym.Value&0xFF) + } else if sym.Has(FlagLabelRef) { + // Word with label reference + line = fmt.Sprintf("%s\t!8 <%s, >%s", sym.FullName(), sym.LabelRef, sym.LabelRef) + } else { + // Word variable (split into low byte, high byte) + lo := sym.Value & 0xFF + hi := (sym.Value >> 8) & 0xFF + line = fmt.Sprintf("%s\t!8 $%02x, $%02x", sym.FullName(), lo, hi) + } + + lines = append(lines, line) + } + + if hasVars { + // Prepend header + result := []string{ + ";Variables (from c65gm)", + "", + } + result = append(result, lines...) + result = append(result, "") // blank line after + return result + } + + return nil +} diff --git a/internal/compiler/symboltable_test.go b/internal/compiler/symboltable_test.go new file mode 100644 index 0000000..03932f5 --- /dev/null +++ b/internal/compiler/symboltable_test.go @@ -0,0 +1,695 @@ +package compiler + +import ( + "strings" + "testing" +) + +func TestSymbolFlags(t *testing.T) { + tests := []struct { + name string + flags SymbolFlags + isByte bool + isWord bool + isConst bool + isAbs bool + isZP bool + isZPPtr bool + }{ + { + name: "byte variable", + flags: FlagByte, + isByte: true, + isWord: false, + isConst: false, + }, + { + name: "word variable", + flags: FlagWord, + isByte: false, + isWord: true, + isConst: false, + }, + { + name: "byte constant", + flags: FlagByte | FlagConst, + isByte: true, + isConst: true, + }, + { + name: "absolute zero-page", + flags: FlagWord | FlagAbsolute | FlagZeroPage, + isWord: true, + isAbs: true, + isZP: true, + isZPPtr: true, + }, + { + name: "absolute non-zero-page", + flags: FlagWord | FlagAbsolute, + isWord: true, + isAbs: true, + isZP: false, + isZPPtr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Symbol{Flags: tt.flags} + + if s.IsByte() != tt.isByte { + t.Errorf("IsByte() = %v, want %v", s.IsByte(), tt.isByte) + } + if s.IsWord() != tt.isWord { + t.Errorf("IsWord() = %v, want %v", s.IsWord(), tt.isWord) + } + if s.IsConst() != tt.isConst { + t.Errorf("IsConst() = %v, want %v", s.IsConst(), tt.isConst) + } + if s.IsAbsolute() != tt.isAbs { + t.Errorf("IsAbsolute() = %v, want %v", s.IsAbsolute(), tt.isAbs) + } + if s.IsZeroPage() != tt.isZP { + t.Errorf("IsZeroPage() = %v, want %v", s.IsZeroPage(), tt.isZP) + } + if s.IsZeroPagePointer() != tt.isZPPtr { + t.Errorf("IsZeroPagePointer() = %v, want %v", s.IsZeroPagePointer(), tt.isZPPtr) + } + }) + } +} + +func TestSymbolFullName(t *testing.T) { + tests := []struct { + name string + symName string + scope string + expected string + }{ + {"global", "counter", "", "counter"}, + {"local", "temp", "main", "main.temp"}, + {"nested", "var", "outer.inner", "outer.inner.var"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Symbol{Name: tt.symName, Scope: tt.scope} + if got := s.FullName(); got != tt.expected { + t.Errorf("FullName() = %q, want %q", got, tt.expected) + } + }) + } +} + +func TestAddVar(t *testing.T) { + st := NewSymbolTable() + + // Add byte var + err := st.AddVar("counter", "", KindByte, 0) + if err != nil { + t.Fatalf("AddVar() error = %v", err) + } + + sym := st.Get("counter") + if sym == nil { + t.Fatal("symbol not found") + } + if !sym.IsByte() { + t.Error("expected byte variable") + } + if sym.IsConst() { + t.Error("should not be const") + } + + // Add word var + err = st.AddVar("ptr", "", KindWord, 0x1234) + if err != nil { + t.Fatalf("AddVar() error = %v", err) + } + + sym = st.Get("ptr") + if !sym.IsWord() { + t.Error("expected word variable") + } + if sym.Value != 0x1234 { + t.Errorf("Value = %d, want %d", sym.Value, 0x1234) + } + + // Test byte value range check + err = st.AddVar("bad", "", KindByte, 256) + if err == nil { + t.Error("expected error for byte value > 255") + } +} + +func TestAddConst(t *testing.T) { + st := NewSymbolTable() + + err := st.AddConst("MAX", "", KindByte, 255) + if err != nil { + t.Fatalf("AddConst() error = %v", err) + } + + sym := st.Get("MAX") + if sym == nil { + t.Fatal("symbol not found") + } + if !sym.IsConst() { + t.Error("expected constant") + } + if !sym.IsByte() { + t.Error("expected byte constant") + } + if sym.Value != 255 { + t.Errorf("Value = %d, want 255", sym.Value) + } + + // Test byte range check + err = st.AddConst("BAD", "", KindByte, 300) + if err == nil { + t.Error("expected error for byte const > 255") + } +} + +func TestAddAbsolute(t *testing.T) { + st := NewSymbolTable() + + // Zero-page byte + err := st.AddAbsolute("ZP_VAR", "", KindByte, 0x80) + if err != nil { + t.Fatalf("AddAbsolute() error = %v", err) + } + + sym := st.Get("ZP_VAR") + if !sym.IsAbsolute() { + t.Error("expected absolute") + } + if !sym.IsZeroPage() { + t.Error("expected zero-page flag for addr < $100") + } + if sym.AbsAddr != 0x80 { + t.Errorf("AbsAddr = $%04X, want $0080", sym.AbsAddr) + } + + // Zero-page word pointer + err = st.AddAbsolute("ZP_PTR", "", KindWord, 0xFE) + if err != nil { + t.Fatalf("AddAbsolute() error = %v", err) + } + + sym = st.Get("ZP_PTR") + if !sym.IsZeroPagePointer() { + t.Error("expected zero-page pointer (word addr < $FF)") + } + + // Non-zero-page + err = st.AddAbsolute("VIC", "", KindWord, 0xD000) + if err != nil { + t.Fatalf("AddAbsolute() error = %v", err) + } + + sym = st.Get("VIC") + if !sym.IsAbsolute() { + t.Error("expected absolute") + } + if sym.IsZeroPage() { + t.Error("should not have zero-page flag") + } + if sym.AbsAddr != 0xD000 { + t.Errorf("AbsAddr = $%04X, want $D000", sym.AbsAddr) + } +} + +func TestAddLabel(t *testing.T) { + st := NewSymbolTable() + + err := st.AddLabel("handler", "", "irq_vector") + if err != nil { + t.Fatalf("AddLabel() error = %v", err) + } + + sym := st.Get("handler") + if sym == nil { + t.Fatal("symbol not found") + } + if !sym.IsWord() { + t.Error("label ref should be word") + } + if !sym.Has(FlagLabelRef) { + t.Error("expected FlagLabelRef") + } + if sym.LabelRef != "irq_vector" { + t.Errorf("LabelRef = %q, want %q", sym.LabelRef, "irq_vector") + } +} + +func TestRedeclaration(t *testing.T) { + st := NewSymbolTable() + + err := st.AddVar("test", "", KindByte, 0) + if err != nil { + t.Fatalf("first AddVar() error = %v", err) + } + + // Attempt redeclaration + err = st.AddVar("test", "", KindByte, 0) + if err == nil { + t.Error("expected error on redeclaration") + } + + // Different scope should be OK + err = st.AddVar("test", "main", KindByte, 0) + if err != nil { + t.Errorf("AddVar() with different scope error = %v", err) + } +} + +func TestLookup(t *testing.T) { + st := NewSymbolTable() + + // Global variable + st.AddVar("global", "", KindByte, 0) + + // Local in main + st.AddVar("local", "main", KindByte, 0) + + // Local in nested function + st.AddVar("inner", "main.helper", KindByte, 0) + + tests := []struct { + name string + searchName string + currentScopes []string + expectFound bool + expectFull string + }{ + { + name: "find global from empty scope", + searchName: "global", + currentScopes: []string{}, + expectFound: true, + expectFull: "global", + }, + { + name: "find global from main scope", + searchName: "global", + currentScopes: []string{"main"}, + expectFound: true, + expectFull: "global", + }, + { + name: "find local from same scope", + searchName: "local", + currentScopes: []string{"main"}, + expectFound: true, + expectFull: "main.local", + }, + { + name: "shadow global with local", + searchName: "local", + currentScopes: []string{"main"}, + expectFound: true, + expectFull: "main.local", + }, + { + name: "find inner from nested scope", + searchName: "inner", + currentScopes: []string{"main", "main.helper"}, + expectFound: true, + expectFull: "main.helper.inner", + }, + { + name: "not found", + searchName: "notexist", + currentScopes: []string{}, + expectFound: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sym := st.Lookup(tt.searchName, tt.currentScopes) + + if tt.expectFound { + if sym == nil { + t.Fatal("symbol not found") + } + if got := sym.FullName(); got != tt.expectFull { + t.Errorf("FullName() = %q, want %q", got, tt.expectFull) + } + } else { + if sym != nil { + t.Errorf("expected not found, got %v", sym) + } + } + }) + } +} + +func TestExpandName(t *testing.T) { + st := NewSymbolTable() + + st.AddVar("global", "", KindByte, 0) + st.AddVar("local", "main", KindByte, 0) + + tests := []struct { + name string + searchName string + currentScopes []string + expected string + }{ + {"expand local", "local", []string{"main"}, "main.local"}, + {"expand global", "global", []string{"main"}, "global"}, + {"no expansion needed", "notfound", []string{}, "notfound"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := st.ExpandName(tt.searchName, tt.currentScopes) + if got != tt.expected { + t.Errorf("ExpandName() = %q, want %q", got, tt.expected) + } + }) + } +} + +func TestInsertionOrder(t *testing.T) { + st := NewSymbolTable() + + names := []string{"first", "second", "third"} + for _, name := range names { + st.AddVar(name, "", KindByte, 0) + } + + symbols := st.Symbols() + if len(symbols) != len(names) { + t.Fatalf("Count = %d, want %d", len(symbols), len(names)) + } + + for i, name := range names { + if symbols[i].Name != name { + t.Errorf("symbols[%d].Name = %q, want %q", i, symbols[i].Name, name) + } + } +} + +func TestCount(t *testing.T) { + st := NewSymbolTable() + + if st.Count() != 0 { + t.Errorf("initial Count() = %d, want 0", st.Count()) + } + + st.AddVar("a", "", KindByte, 0) + st.AddVar("b", "", KindByte, 0) + + if st.Count() != 2 { + t.Errorf("Count() = %d, want 2", st.Count()) + } +} + +func TestSymbolString(t *testing.T) { + tests := []struct { + name string + setup func(*SymbolTable) *Symbol + contains []string + }{ + { + name: "byte variable", + setup: func(st *SymbolTable) *Symbol { + st.AddVar("test", "", KindByte, 0) + return st.Get("test") + }, + contains: []string{"Name=test", "BYTE"}, + }, + { + name: "word constant", + setup: func(st *SymbolTable) *Symbol { + st.AddConst("MAX", "", KindWord, 65535) + return st.Get("MAX") + }, + contains: []string{"Name=MAX", "WORD", "CONST=65535"}, + }, + { + name: "zero-page pointer", + setup: func(st *SymbolTable) *Symbol { + st.AddAbsolute("ptr", "", KindWord, 0x80) + return st.Get("ptr") + }, + contains: []string{"Name=ptr", "WORD", "@$0080", "ZP"}, + }, + { + name: "label reference", + setup: func(st *SymbolTable) *Symbol { + st.AddLabel("handler", "", "irq") + return st.Get("handler") + }, + contains: []string{"Name=handler", "->irq"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + st := NewSymbolTable() + sym := tt.setup(st) + + str := sym.String() + for _, want := range tt.contains { + if !containsSubstring(str, want) { + t.Errorf("String() = %q, missing %q", str, want) + } + } + }) + } +} + +func containsSubstring(s, substr string) bool { + return len(s) >= len(substr) && + (s == substr || len(s) > len(substr) && containsAt(s, substr)) +} + +func containsAt(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// Code generation tests + +func TestGenerateConstants(t *testing.T) { + st := NewSymbolTable() + + // Add some constants + st.AddConst("MAX", "", KindByte, 255) + st.AddConst("SIZE", "", KindWord, 0x1234) + st.AddVar("notconst", "", KindByte, 0) // should be skipped + + lines := GenerateConstants(st) + + if len(lines) == 0 { + t.Fatal("expected output lines") + } + + // Check header + if lines[0] != ";Constant values (from c65gm)" { + t.Errorf("expected header comment, got %q", lines[0]) + } + + // Find constant definitions + output := strings.Join(lines, "\n") + + if !strings.Contains(output, "MAX = $ff") { + t.Error("expected byte constant with lowercase hex") + } + if !strings.Contains(output, "; 255") { + t.Error("expected decimal comment for byte constant") + } + if !strings.Contains(output, "SIZE = $1234") { + t.Error("expected word constant") + } + if strings.Contains(output, "notconst") { + t.Error("regular variable should not appear in constants") + } +} + +func TestGenerateAbsolutes(t *testing.T) { + st := NewSymbolTable() + + // Zero-page + st.AddAbsolute("ZP_VAR", "", KindByte, 0x80) + st.AddAbsolute("ZP_PTR", "", KindWord, 0xFE) + + // Non-zero-page + st.AddAbsolute("VIC", "", KindWord, 0xD000) + + // Regular var (should be skipped) + st.AddVar("regular", "", KindByte, 0) + + lines := GenerateAbsolutes(st) + + if len(lines) == 0 { + t.Fatal("expected output lines") + } + + // Check header + if lines[0] != ";Absolute variable definitions (from c65gm)" { + t.Errorf("expected header comment, got %q", lines[0]) + } + + output := strings.Join(lines, "\n") + + // Zero-page should use 2 hex digits + if !strings.Contains(output, "ZP_VAR = $80") { + t.Error("expected zero-page with 2 hex digits") + } + if !strings.Contains(output, "ZP_PTR = $fe") { + t.Error("expected zero-page pointer with 2 hex digits") + } + + // Non-zero-page should use 4 hex digits + if !strings.Contains(output, "VIC = $d000") { + t.Error("expected non-zero-page with 4 hex digits") + } + + if strings.Contains(output, "regular") { + t.Error("regular variable should not appear in absolutes") + } +} + +func TestGenerateVariables(t *testing.T) { + st := NewSymbolTable() + + // Byte variable + st.AddVar("counter", "", KindByte, 42) + + // Word variable + st.AddVar("ptr", "", KindWord, 0x1234) + + // Label reference + st.AddLabel("handler", "", "irq_routine") + + // Const (should be skipped) + st.AddConst("SKIP", "", KindByte, 99) + + // Absolute (should be skipped) + st.AddAbsolute("SKIP2", "", KindByte, 0x80) + + lines := GenerateVariables(st) + + if len(lines) == 0 { + t.Fatal("expected output lines") + } + + // Check header + if lines[0] != ";Variables (from c65gm)" { + t.Errorf("expected header comment, got %q", lines[0]) + } + + output := strings.Join(lines, "\n") + + // Byte variable with decimal comment + if !strings.Contains(output, "counter\t!8 $2a") { + t.Error("expected byte variable declaration") + } + if !strings.Contains(output, "; 42") { + t.Error("expected decimal comment for byte") + } + + // Word variable (low, high) + if !strings.Contains(output, "ptr\t!8 $34, $12") { + t.Error("expected word variable as two bytes (lo, hi)") + } + + // Label reference + if !strings.Contains(output, "handler\t!8 irq_routine") { + t.Error("expected label reference with < and >") + } + + // Should not contain constants or absolutes + if strings.Contains(output, "SKIP") { + t.Error("constants should not appear in variables") + } + if strings.Contains(output, "SKIP2") { + t.Error("absolutes should not appear in variables") + } +} + +func TestGenerateEmpty(t *testing.T) { + st := NewSymbolTable() + + // Empty table + if lines := GenerateConstants(st); lines != nil { + t.Error("expected nil for empty constants") + } + if lines := GenerateAbsolutes(st); lines != nil { + t.Error("expected nil for empty absolutes") + } + if lines := GenerateVariables(st); lines != nil { + t.Error("expected nil for empty variables") + } + + // Only variables (no constants/absolutes) + st.AddVar("test", "", KindByte, 0) + + if lines := GenerateConstants(st); lines != nil { + t.Error("expected nil when no constants exist") + } + if lines := GenerateAbsolutes(st); lines != nil { + t.Error("expected nil when no absolutes exist") + } +} + +func TestGenerateScopedVariables(t *testing.T) { + st := NewSymbolTable() + + st.AddVar("global", "", KindByte, 0) + st.AddVar("local", "main", KindByte, 0) + st.AddVar("nested", "main.helper", KindByte, 0) + + lines := GenerateVariables(st) + output := strings.Join(lines, "\n") + + // Check full names are used + if !strings.Contains(output, "global\t!8") { + t.Error("expected global variable") + } + if !strings.Contains(output, "main.local\t!8") { + t.Error("expected scoped variable with full name") + } + if !strings.Contains(output, "main.helper.nested\t!8") { + t.Error("expected nested scoped variable with full name") + } +} + +func TestGenerateHexLowercase(t *testing.T) { + st := NewSymbolTable() + + st.AddConst("TEST", "", KindByte, 0xAB) + st.AddAbsolute("ADDR", "", KindWord, 0xDEAD) + st.AddVar("VAR", "", KindWord, 0xBEEF) + + constLines := GenerateConstants(st) + absLines := GenerateAbsolutes(st) + varLines := GenerateVariables(st) + + output := strings.Join(append(append(constLines, absLines...), varLines...), "\n") + + // Check all hex is lowercase + if strings.Contains(output, "$AB") || strings.Contains(output, "$DEAD") || strings.Contains(output, "$BEEF") { + t.Error("hex digits should be lowercase") + } + + if !strings.Contains(output, "$ab") { + t.Error("expected lowercase hex in constant") + } + if !strings.Contains(output, "$dead") { + t.Error("expected lowercase hex in absolute") + } + if !strings.Contains(output, "$be") && !strings.Contains(output, "$ef") { + t.Error("expected lowercase hex in variable") + } +}