diff --git a/internal/compiler/funchandler.go b/internal/compiler/funchandler.go index 6920775..70a2f23 100644 --- a/internal/compiler/funchandler.go +++ b/internal/compiler/funchandler.go @@ -5,6 +5,7 @@ import ( "strings" "c65gm/internal/preproc" + "c65gm/internal/utils" ) // ParamDirection represents parameter passing direction @@ -499,11 +500,11 @@ func (fh *FunctionHandler) processConstValue(value uint16, param *FuncParam, fun return nil } -// parseImplicitDecl parses {BYTE varname} or {WORD varname} and adds to symbol table +// parseImplicitDecl parses {BYTE varname} or {WORD varname} or {BYTE varname @ address} and adds to symbol table func (fh *FunctionHandler) parseImplicitDecl(decl string, funcName string) error { parts := strings.Fields(decl) - if len(parts) != 2 { - return fmt.Errorf("implicit declaration must be 'BYTE name' or 'WORD name', got: %q", decl) + if len(parts) != 2 && len(parts) != 4 { + return fmt.Errorf("implicit declaration must be 'TYPE name' or 'TYPE name @ addr', got: %q", decl) } typeStr := strings.ToUpper(parts[0]) @@ -519,8 +520,39 @@ func (fh *FunctionHandler) parseImplicitDecl(decl string, funcName string) error return fmt.Errorf("implicit declaration type must be BYTE or WORD, got: %s", typeStr) } - // Add variable to symbol table with function scope - return fh.symTable.AddVar(varName, funcName, kind, 0) + if len(parts) == 2 { + // Simple: BYTE name or WORD name + return fh.symTable.AddVar(varName, funcName, kind, 0) + } + + // Extended: BYTE name @ address or WORD name @ address + operator := parts[2] + addrStr := parts[3] + + if operator != "@" { + return fmt.Errorf("expected '@' operator, got: %q", operator) + } + + // Create constant lookup function for address evaluation + constLookup := func(name string) (int64, bool) { + sym := fh.symTable.Lookup(name, []string{funcName}) + if sym != nil && sym.IsConst() { + return int64(sym.Value), true + } + return 0, false + } + + // Parse address (supports $hex and decimal) using EvaluateExpression + addr, err := utils.EvaluateExpression(addrStr, constLookup) + if err != nil { + return fmt.Errorf("invalid address %q: %w", addrStr, err) + } + + if addr < 0 || addr > 0xFFFF { + return fmt.Errorf("absolute address $%X out of range", addr) + } + + return fh.symTable.AddAbsolute(varName, funcName, kind, uint16(addr)) } // EndFunction pops all functions from the stack (called by FEND) diff --git a/internal/compiler/funchandler_test.go b/internal/compiler/funchandler_test.go index 74f4a0a..03d11ac 100644 --- a/internal/compiler/funchandler_test.go +++ b/internal/compiler/funchandler_test.go @@ -686,3 +686,346 @@ func TestCurrentFunction(t *testing.T) { t.Errorf("expected '', got %q", fh.CurrentFunction()) } } + +func TestHandleFuncDecl_AbsoluteVariables(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + _, err := fh.HandleFuncDecl(makeLine("FUNC test_abs ( {BYTE param1 @ $fa} {WORD param2 @ $fb} )")) + if err != nil { + t.Fatalf("HandleFuncDecl failed: %v", err) + } + + // Check that variables were declared as absolute + symByte := st.Lookup("param1", []string{"test_abs"}) + if symByte == nil { + t.Fatal("parameter 'param1' not declared") + } + if !symByte.IsByte() { + t.Error("parameter 'param1' should be byte") + } + if !symByte.IsAbsolute() { + t.Error("parameter 'param1' should be absolute") + } + if symByte.AbsAddr != 0xfa { + t.Errorf("parameter 'param1' address = $%02x, want $fa", symByte.AbsAddr) + } + if !symByte.IsZeroPage() { + t.Error("parameter 'param1' should be in zero page") + } + + symWord := st.Lookup("param2", []string{"test_abs"}) + if symWord == nil { + t.Fatal("parameter 'param2' not declared") + } + if !symWord.IsWord() { + t.Error("parameter 'param2' should be word") + } + if !symWord.IsAbsolute() { + t.Error("parameter 'param2' should be absolute") + } + if symWord.AbsAddr != 0xfb { + t.Errorf("parameter 'param2' address = $%02x, want $fb", symWord.AbsAddr) + } + if !symWord.IsZeroPage() { + t.Error("parameter 'param2' should be in zero page") + } + + // Check function declaration + funcDecl := fh.findFunc("test_abs") + if funcDecl == nil { + t.Fatal("function not found") + } + if len(funcDecl.Params) != 2 { + t.Fatalf("expected 2 params, got %d", len(funcDecl.Params)) + } +} + +func TestHandleFuncDecl_AbsoluteWithDirections(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + _, err := fh.HandleFuncDecl(makeLine("FUNC test_abs_dir ( in:{BYTE input @ $fa} out:{BYTE output @ $fb} io:{WORD data @ $fc} )")) + if err != nil { + t.Fatalf("HandleFuncDecl failed: %v", err) + } + + funcDecl := fh.findFunc("test_abs_dir") + if funcDecl == nil { + t.Fatal("function not found") + } + + if len(funcDecl.Params) != 3 { + t.Fatalf("expected 3 params, got %d", len(funcDecl.Params)) + } + + if funcDecl.Params[0].Direction != DirIn { + t.Error("param 0 should be DirIn") + } + if funcDecl.Params[1].Direction != DirOut { + t.Error("param 1 should be DirOut") + } + if funcDecl.Params[2].Direction != (DirIn | DirOut) { + t.Error("param 2 should be DirIn|DirOut") + } + + // Verify absolute addresses + if funcDecl.Params[0].Symbol.AbsAddr != 0xfa { + t.Errorf("param 0 address = $%02x, want $fa", funcDecl.Params[0].Symbol.AbsAddr) + } + if funcDecl.Params[1].Symbol.AbsAddr != 0xfb { + t.Errorf("param 1 address = $%02x, want $fb", funcDecl.Params[1].Symbol.AbsAddr) + } + if funcDecl.Params[2].Symbol.AbsAddr != 0xfc { + t.Errorf("param 2 address = $%02x, want $fc", funcDecl.Params[2].Symbol.AbsAddr) + } +} + +func TestHandleFuncDecl_AbsoluteMixedParams(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + // Mix absolute and regular params + _, err := fh.HandleFuncDecl(makeLine("FUNC test_mixed ( {BYTE abs_param @ $fa} {WORD reg_param} )")) + if err != nil { + t.Fatalf("HandleFuncDecl failed: %v", err) + } + + symAbs := st.Lookup("abs_param", []string{"test_mixed"}) + if symAbs == nil { + t.Fatal("parameter 'abs_param' not declared") + } + if !symAbs.IsAbsolute() { + t.Error("parameter 'abs_param' should be absolute") + } + + symReg := st.Lookup("reg_param", []string{"test_mixed"}) + if symReg == nil { + t.Fatal("parameter 'reg_param' not declared") + } + if symReg.IsAbsolute() { + t.Error("parameter 'reg_param' should not be absolute") + } +} + +func TestHandleFuncDecl_AbsoluteErrors(t *testing.T) { + tests := []struct { + name string + line string + wantErr string + }{ + { + name: "invalid operator", + line: "FUNC test ( {BYTE x = $fa} )", + wantErr: "expected '@' operator", + }, + { + name: "invalid address", + line: "FUNC test ( {BYTE x @ invalid} )", + wantErr: "invalid address", + }, + { + name: "address out of range", + line: "FUNC test ( {BYTE x @ $10000} )", + wantErr: "out of range", + }, + { + name: "negative address", + line: "FUNC test ( {BYTE x @ -1} )", + wantErr: "invalid address", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + _, err := fh.HandleFuncDecl(makeLine(tt.line)) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error %q does not contain %q", err.Error(), tt.wantErr) + } + }) + } +} + +func TestHandleFuncCall_AbsoluteParams(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + // Declare function with absolute params + fh.HandleFuncDecl(makeLine("FUNC test_abs ( {BYTE param_a @ $fa} {WORD param_b @ $fb} )")) + + // Declare caller variables + st.AddVar("var_a", "", KindByte, 0) + st.AddVar("var_b", "", KindWord, 0) + + asm, err := fh.HandleFuncCall(makeLine("CALL test_abs ( var_a var_b )")) + if err != nil { + t.Fatalf("HandleFuncCall failed: %v", err) + } + + // Check generated assembly uses correct names + expectedLines := []string{ + " lda var_a", + " sta test_abs_param_a", + " lda var_b", + " sta test_abs_param_b", + " lda var_b+1", + " sta test_abs_param_b+1", + " jsr test_abs", + } + + if len(asm) != len(expectedLines) { + t.Fatalf("expected %d asm lines, got %d", len(expectedLines), len(asm)) + } + + for i, expected := range expectedLines { + if asm[i] != expected { + t.Errorf("asm[%d] = %q, want %q", i, asm[i], expected) + } + } +} + +func TestHandleFuncCall_AbsoluteOutParams(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + // Declare function with absolute out param + fh.HandleFuncDecl(makeLine("FUNC get_result ( out:{BYTE result @ $fa} )")) + + // Declare caller variable + st.AddVar("output", "", KindByte, 0) + + asm, err := fh.HandleFuncCall(makeLine("CALL get_result ( output )")) + if err != nil { + t.Fatalf("HandleFuncCall failed: %v", err) + } + + // Should have JSR and OUT assignment + found_jsr := false + found_out := false + for _, line := range asm { + if strings.Contains(line, "jsr get_result") { + found_jsr = true + } + if strings.Contains(line, "lda get_result_result") { + found_out = true + } + } + + if !found_jsr { + t.Error("missing jsr instruction") + } + if !found_out { + t.Error("missing out assignment") + } +} + +func TestParseImplicitDecl_Absolute(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + tests := []struct { + name string + decl string + funcName string + wantErr bool + checkFn func(*testing.T, *SymbolTable) + }{ + { + name: "byte at hex address", + decl: "BYTE x @ $fa", + funcName: "test", + wantErr: false, + checkFn: func(t *testing.T, st *SymbolTable) { + sym := st.Lookup("x", []string{"test"}) + if sym == nil { + t.Fatal("symbol not found") + } + if !sym.IsAbsolute() || sym.AbsAddr != 0xfa { + t.Errorf("expected absolute at $fa, got absolute=%v addr=$%02x", sym.IsAbsolute(), sym.AbsAddr) + } + }, + }, + { + name: "word at decimal address", + decl: "WORD ptr @ 251", + funcName: "test", + wantErr: false, + checkFn: func(t *testing.T, st *SymbolTable) { + sym := st.Lookup("ptr", []string{"test"}) + if sym == nil { + t.Fatal("symbol not found") + } + if !sym.IsAbsolute() || sym.AbsAddr != 251 { + t.Errorf("expected absolute at 251, got absolute=%v addr=%d", sym.IsAbsolute(), sym.AbsAddr) + } + }, + }, + { + name: "invalid operator", + decl: "BYTE x = $fa", + funcName: "test", + wantErr: true, + }, + { + name: "too few parts", + decl: "BYTE x @", + funcName: "test", + wantErr: true, + }, + { + name: "too many parts", + decl: "BYTE x @ $fa extra", + funcName: "test", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Clear symbol table for each test + st = NewSymbolTable() + fh.symTable = st + + err := fh.parseImplicitDecl(tt.decl, tt.funcName) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.checkFn != nil { + tt.checkFn(t, st) + } + }) + } +}