diff --git a/internal/compiler/funchandler.go b/internal/compiler/funchandler.go new file mode 100644 index 0000000..ee254e5 --- /dev/null +++ b/internal/compiler/funchandler.go @@ -0,0 +1,657 @@ +package compiler + +import ( + "fmt" + "strings" + + "c65gm/internal/preproc" +) + +// ParamDirection represents parameter passing direction +type ParamDirection uint8 + +const ( + DirIn ParamDirection = 1 << iota + DirOut +) + +func (d ParamDirection) Has(flag ParamDirection) bool { + return d&flag != 0 +} + +// FuncParam represents a function parameter +type FuncParam struct { + Symbol *Symbol + Direction ParamDirection +} + +// FuncDecl represents a function declaration +type FuncDecl struct { + Name string + Params []*FuncParam +} + +// FunctionHandler manages function declarations and calls +type FunctionHandler struct { + functions []*FuncDecl + currentFuncs []string // stack of current function names (for nested scope) + symTable *SymbolTable + labelStack *LabelStack + constStrHandler *ConstantStringHandler + pragma *preproc.Pragma +} + +// NewFunctionHandler creates a new function handler +func NewFunctionHandler(st *SymbolTable, ls *LabelStack, csh *ConstantStringHandler, pragma *preproc.Pragma) *FunctionHandler { + return &FunctionHandler{ + functions: make([]*FuncDecl, 0), + currentFuncs: make([]string, 0), + symTable: st, + labelStack: ls, + constStrHandler: csh, + pragma: pragma, + } +} + +// HandleFuncDecl parses and processes a FUNC declaration +// Syntax: FUNC name ( param1 param2 ... ) +// Or: FUNC name (void function) +func (fh *FunctionHandler) HandleFuncDecl(line preproc.Line) ([]string, error) { + // Normalize parentheses and commas + text := fixIntuitiveFuncs(line.Text) + + params, err := parseParams(text) + if err != nil { + return nil, fmt.Errorf("%s:%d: %w", line.Filename, line.LineNo, err) + } + + if len(params) < 2 { + return nil, fmt.Errorf("%s:%d: FUNC: expected at least function name", line.Filename, line.LineNo) + } + + if strings.ToUpper(params[0]) != "FUNC" { + return nil, fmt.Errorf("%s:%d: not a FUNC declaration", line.Filename, line.LineNo) + } + + funcName := params[1] + + // Check for redeclaration + if fh.FuncExists(funcName) { + return nil, fmt.Errorf("%s:%d: function %q already declared", line.Filename, line.LineNo, funcName) + } + + // Push function name to current function stack early + // (so param declarations get correct scope) + fh.currentFuncs = append(fh.currentFuncs, funcName) + + // Parse parameters + var funcParams []*FuncParam + + if len(params) == 2 { + // Void function: FUNC name + // No parameters + } else if len(params) >= 5 { + // FUNC name ( param1 param2 ) + if params[2] != "(" || params[len(params)-1] != ")" { + fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1] + return nil, fmt.Errorf("%s:%d: FUNC: expected parentheses around parameters", line.Filename, line.LineNo) + } + + // Extract params between ( and ) - need to handle {BYTE x} specially + rawParamTokens := params[3 : len(params)-1] + paramSpecs, err := buildComplexParams(rawParamTokens) + if err != nil { + fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1] + return nil, fmt.Errorf("%s:%d: FUNC %s: %w", line.Filename, line.LineNo, funcName, err) + } + + for _, spec := range paramSpecs { + direction, varName, isImplicit, implicitDecl, err := parseParamSpec(spec) + if err != nil { + fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1] + return nil, fmt.Errorf("%s:%d: FUNC %s: %w", line.Filename, line.LineNo, funcName, err) + } + + if isImplicit { + // Parse and add implicit variable declaration + // Format: {BYTE varname} or {WORD varname} + if err := fh.parseImplicitDecl(implicitDecl, funcName); err != nil { + fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1] + return nil, fmt.Errorf("%s:%d: FUNC %s: implicit declaration: %w", line.Filename, line.LineNo, funcName, err) + } + } + + // Look up variable in symbol table + sym := fh.symTable.Lookup(varName, []string{funcName}) + if sym == nil { + fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1] + return nil, fmt.Errorf("%s:%d: FUNC %s: parameter %q not declared", line.Filename, line.LineNo, funcName, varName) + } + + if sym.IsConst() { + fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1] + return nil, fmt.Errorf("%s:%d: FUNC %s: parameter %q cannot be a constant", line.Filename, line.LineNo, funcName, varName) + } + + funcParams = append(funcParams, &FuncParam{ + Symbol: sym, + Direction: direction, + }) + } + } else { + fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1] + return nil, fmt.Errorf("%s:%d: FUNC: invalid syntax", line.Filename, line.LineNo) + } + + // Store function declaration + fh.functions = append(fh.functions, &FuncDecl{ + Name: funcName, + Params: funcParams, + }) + + // Generate assembler label + return []string{funcName}, nil +} + +// buildComplexParams handles parameter lists that may contain {BYTE x} style declarations +// Tokens like {BYTE x} are spread across multiple tokens and need to be reassembled +func buildComplexParams(tokens []string) ([]string, error) { + var result []string + var current string + inBraces := false + + for _, token := range tokens { + hasStart := strings.Contains(token, "{") + hasEnd := strings.Contains(token, "}") + + if !inBraces { + // Not currently in braces + if hasEnd && !hasStart { + return nil, fmt.Errorf("unexpected } without matching {") + } + if hasStart { + // Starting a brace block + inBraces = true + current = token + // Check if it also ends on same token + if hasEnd { + result = append(result, current) + current = "" + inBraces = false + } + } else { + // Regular param + result = append(result, token) + } + } else { + // Currently accumulating in braces + if hasStart { + return nil, fmt.Errorf("unexpected { while already in braces") + } + current += " " + token + if hasEnd { + result = append(result, current) + current = "" + inBraces = false + } + } + } + + if inBraces { + return nil, fmt.Errorf("unclosed { in parameter list") + } + + return result, nil +} + +// HandleFuncCall generates code for a function call +// Syntax: CALL funcname ( arg1 arg2 ... ) +// Or: funcname ( arg1 arg2 ... ) +func (fh *FunctionHandler) HandleFuncCall(line preproc.Line) ([]string, error) { + // Normalize parentheses and commas + text := fixIntuitiveFuncs(line.Text) + + params, err := parseParams(text) + if err != nil { + return nil, fmt.Errorf("%s:%d: %w", line.Filename, line.LineNo, err) + } + + if len(params) < 1 { + return nil, fmt.Errorf("%s:%d: CALL: empty line", line.Filename, line.LineNo) + } + + // Check if starts with CALL keyword + startsWithCall := strings.ToUpper(params[0]) == "CALL" + + funcNameIdx := 0 + if startsWithCall { + if len(params) < 2 { + return nil, fmt.Errorf("%s:%d: CALL: expected function name", line.Filename, line.LineNo) + } + funcNameIdx = 1 + } + + funcName := params[funcNameIdx] + + // Check if function exists + funcDecl := fh.findFunc(funcName) + if funcDecl == nil { + return nil, fmt.Errorf("%s:%d: function %q not declared", line.Filename, line.LineNo, funcName) + } + + // Parse call arguments + var callArgs []string + + if len(params) == funcNameIdx+1 { + // No arguments: funcname or CALL funcname + callArgs = []string{} + } else if len(params) >= funcNameIdx+4 { + // funcname ( arg1 arg2 ) or CALL funcname ( arg1 arg2 ) + if params[funcNameIdx+1] != "(" || params[len(params)-1] != ")" { + return nil, fmt.Errorf("%s:%d: CALL %s: expected parentheses around arguments", line.Filename, line.LineNo, funcName) + } + callArgs = params[funcNameIdx+2 : len(params)-1] + } else { + return nil, fmt.Errorf("%s:%d: CALL %s: invalid syntax", line.Filename, line.LineNo, funcName) + } + + // Check argument count matches + if len(callArgs) != len(funcDecl.Params) { + return nil, fmt.Errorf("%s:%d: CALL %s: expected %d arguments, got %d", + line.Filename, line.LineNo, funcName, len(funcDecl.Params), len(callArgs)) + } + + // Get pragma set for this line + pragmaSet := fh.pragma.GetPragmaSetByIndex(line.PragmaSetIndex) + + var asmLines []string + var inAssigns []string + var outAssigns []string + + // Process each argument + for i, arg := range callArgs { + param := funcDecl.Params[i] + + // Handle different argument types + if strings.HasPrefix(arg, "@") { + // Label reference: @labelname + if err := fh.processLabelArg(arg, param, funcName, line, &inAssigns); err != nil { + return nil, err + } + } else if strings.HasPrefix(arg, "\"") && strings.HasSuffix(arg, "\"") { + // String constant + if err := fh.processStringArg(arg, param, funcName, line, pragmaSet, &inAssigns); err != nil { + return nil, err + } + } else if sym := fh.symTable.Lookup(arg, fh.currentFuncs); sym != nil { + // Variable reference + if err := fh.processVarArg(sym, param, funcName, line, &inAssigns, &outAssigns); err != nil { + return nil, err + } + } else { + // Numeric constant + if err := fh.processConstArg(arg, param, funcName, line, &inAssigns); err != nil { + return nil, err + } + } + } + + // Generate final assembly + asmLines = append(asmLines, inAssigns...) + asmLines = append(asmLines, fmt.Sprintf(" jsr %s", funcName)) + asmLines = append(asmLines, outAssigns...) + + return asmLines, nil +} + +// processLabelArg handles @label arguments +func (fh *FunctionHandler) processLabelArg(arg string, param *FuncParam, funcName string, line preproc.Line, inAssigns *[]string) error { + labelName := arg[1:] // strip @ + + if param.Symbol.IsByte() { + return fmt.Errorf("%s:%d: CALL %s: cannot pass label to byte parameter", line.Filename, line.LineNo, funcName) + } + + if param.Direction.Has(DirOut) { + return fmt.Errorf("%s:%d: CALL %s: cannot pass label to out/io parameter", line.Filename, line.LineNo, funcName) + } + + *inAssigns = append(*inAssigns, + fmt.Sprintf(" lda #<%s", labelName), + fmt.Sprintf(" sta %s", param.Symbol.FullName()), + fmt.Sprintf(" lda #>%s", labelName), + fmt.Sprintf(" sta %s+1", param.Symbol.FullName()), + ) + + return nil +} + +// processStringArg handles "string" arguments +func (fh *FunctionHandler) processStringArg(arg string, param *FuncParam, funcName string, line preproc.Line, pragmaSet preproc.PragmaSet, inAssigns *[]string) error { + if param.Symbol.IsByte() { + return fmt.Errorf("%s:%d: CALL %s: cannot pass string to byte parameter", line.Filename, line.LineNo, funcName) + } + + if param.Direction.Has(DirOut) { + return fmt.Errorf("%s:%d: CALL %s: cannot pass string to out/io parameter", line.Filename, line.LineNo, funcName) + } + + // Generate label for string constant + labelName := fh.labelStack.Push() + fh.constStrHandler.AddConstStr(labelName, arg, true, pragmaSet) + + *inAssigns = append(*inAssigns, + fmt.Sprintf(" lda #<%s", labelName), + fmt.Sprintf(" sta %s", param.Symbol.FullName()), + fmt.Sprintf(" lda #>%s", labelName), + fmt.Sprintf(" sta %s+1", param.Symbol.FullName()), + ) + + return nil +} + +// processVarArg handles variable arguments +func (fh *FunctionHandler) processVarArg(sym *Symbol, param *FuncParam, funcName string, line preproc.Line, inAssigns, outAssigns *[]string) error { + // Check type compatibility + if (sym.IsByte() && param.Symbol.IsWord()) || (sym.IsWord() && param.Symbol.IsByte()) { + return fmt.Errorf("%s:%d: CALL %s: type mismatch for parameter %s", line.Filename, line.LineNo, funcName, param.Symbol.Name) + } + + if sym.IsConst() { + return fmt.Errorf("%s:%d: CALL %s: cannot pass constant to function", line.Filename, line.LineNo, funcName) + } + + // Generate IN assignments + if param.Direction.Has(DirIn) { + *inAssigns = append(*inAssigns, + fmt.Sprintf(" lda %s", sym.FullName()), + fmt.Sprintf(" sta %s", param.Symbol.FullName()), + ) + if sym.IsWord() { + *inAssigns = append(*inAssigns, + fmt.Sprintf(" lda %s+1", sym.FullName()), + fmt.Sprintf(" sta %s+1", param.Symbol.FullName()), + ) + } + } + + // Generate OUT assignments + if param.Direction.Has(DirOut) { + *outAssigns = append(*outAssigns, + fmt.Sprintf(" lda %s", param.Symbol.FullName()), + fmt.Sprintf(" sta %s", sym.FullName()), + ) + if sym.IsWord() { + *outAssigns = append(*outAssigns, + fmt.Sprintf(" lda %s+1", param.Symbol.FullName()), + fmt.Sprintf(" sta %s+1", sym.FullName()), + ) + } + } + + return nil +} + +// processConstArg handles numeric constant arguments +func (fh *FunctionHandler) processConstArg(arg string, param *FuncParam, funcName string, line preproc.Line, inAssigns *[]string) error { + if param.Direction.Has(DirOut) { + return fmt.Errorf("%s:%d: CALL %s: cannot pass constant to out/io parameter", line.Filename, line.LineNo, funcName) + } + + // Parse numeric value (supports decimal and hex with $ prefix) + var value int64 + var err error + + if strings.HasPrefix(arg, "$") { + _, err = fmt.Sscanf(arg[1:], "%x", &value) + } else { + _, err = fmt.Sscanf(arg, "%d", &value) + } + + if err != nil { + return fmt.Errorf("%s:%d: CALL %s: invalid numeric constant %q", line.Filename, line.LineNo, funcName, arg) + } + + if param.Symbol.IsByte() && (value < 0 || value > 255) { + return fmt.Errorf("%s:%d: CALL %s: constant %d out of byte range", line.Filename, line.LineNo, funcName, value) + } + + if value < 0 || value > 65535 { + return fmt.Errorf("%s:%d: CALL %s: constant %d out of word range", line.Filename, line.LineNo, funcName, value) + } + + lowByte := uint8(value & 0xFF) + highByte := uint8((value >> 8) & 0xFF) + + *inAssigns = append(*inAssigns, + fmt.Sprintf(" lda #%d", lowByte), + fmt.Sprintf(" sta %s", param.Symbol.FullName()), + ) + + if param.Symbol.IsWord() { + // Optimize: only reload A if high byte differs + if highByte != lowByte { + *inAssigns = append(*inAssigns, fmt.Sprintf(" lda #%d", highByte)) + } + *inAssigns = append(*inAssigns, fmt.Sprintf(" sta %s+1", param.Symbol.FullName())) + } + + return nil +} + +// parseImplicitDecl parses {BYTE varname} or {WORD varname} and adds to symbol table +func (fh *FunctionHandler) parseImplicitDecl(decl string, funcName string) error { + parts := strings.Fields(decl) + if len(parts) != 2 { + return fmt.Errorf("implicit declaration must be 'BYTE name' or 'WORD name', got: %q", decl) + } + + typeStr := strings.ToUpper(parts[0]) + varName := parts[1] + + var kind VarKind + switch typeStr { + case "BYTE": + kind = KindByte + case "WORD": + kind = KindWord + default: + return fmt.Errorf("implicit declaration type must be BYTE or WORD, got: %s", typeStr) + } + + // Add variable to symbol table with function scope + return fh.symTable.AddVar(varName, funcName, kind, 0) +} + +// EndFunction pops all functions from the stack (called by FEND) +func (fh *FunctionHandler) EndFunction() { + fh.currentFuncs = fh.currentFuncs[:0] +} + +// FuncExists checks if a function is declared +func (fh *FunctionHandler) FuncExists(name string) bool { + return fh.findFunc(name) != nil +} + +// CurrentFunction returns the current function name (or empty if global scope) +func (fh *FunctionHandler) CurrentFunction() string { + if len(fh.currentFuncs) == 0 { + return "" + } + return fh.currentFuncs[len(fh.currentFuncs)-1] +} + +// findFunc finds a function declaration by name +func (fh *FunctionHandler) findFunc(name string) *FuncDecl { + for _, f := range fh.functions { + if f.Name == name { + return f + } + } + return nil +} + +// fixIntuitiveFuncs normalizes function syntax +// Separates '(' and ')' into own tokens, removes commas +// Example: "func(a,b)" -> "func ( a b )" +func fixIntuitiveFuncs(s string) string { + var result strings.Builder + inString := false + + for i := 0; i < len(s); i++ { + ch := s[i] + + if ch == '"' { + inString = !inString + result.WriteByte(ch) + continue + } + + if !inString { + if ch == '(' || ch == ')' { + result.WriteByte(' ') + result.WriteByte(ch) + result.WriteByte(' ') + } else if ch == ',' { + result.WriteByte(' ') + } else { + result.WriteByte(ch) + } + } else { + result.WriteByte(ch) + } + } + + return normalizeSpaces(result.String()) +} + +// normalizeSpaces reduces multiple spaces to single space +func normalizeSpaces(s string) string { + s = strings.TrimSpace(s) + var result strings.Builder + inString := false + lastWasSpace := false + + for i := 0; i < len(s); i++ { + ch := s[i] + + if ch == '"' { + inString = !inString + result.WriteByte(ch) + lastWasSpace = false + continue + } + + if !inString { + if ch == ' ' || ch == '\t' { + if !lastWasSpace { + result.WriteByte(' ') + lastWasSpace = true + } + } else { + result.WriteByte(ch) + lastWasSpace = false + } + } else { + result.WriteByte(ch) + lastWasSpace = false + } + } + + return result.String() +} + +// parseParams splits line into space-separated parameters, respecting quoted strings +func parseParams(s string) ([]string, error) { + s = strings.TrimSpace(s) + if s == "" { + return []string{}, nil + } + + var params []string + var current strings.Builder + inString := false + + for i := 0; i < len(s); i++ { + ch := s[i] + + if ch == '"' { + inString = !inString + current.WriteByte(ch) + continue + } + + if !inString && (ch == ' ' || ch == '\t') { + if current.Len() > 0 { + params = append(params, current.String()) + current.Reset() + } + } else { + current.WriteByte(ch) + } + } + + if current.Len() > 0 { + params = append(params, current.String()) + } + + if inString { + return nil, fmt.Errorf("unterminated string in line") + } + + return params, nil +} + +// parseParamSpec parses a parameter specification +// Returns: direction, varName, isImplicit, implicitDecl, error +// Examples: +// +// "varname" -> DirIn, "varname", false, "", nil +// "in:varname" -> DirIn, "varname", false, "", nil +// "out:varname" -> DirOut, "varname", false, "", nil +// "io:varname" -> DirIn|DirOut, "varname", false, "", nil +// "{BYTE temp}" -> DirIn, "temp", true, "BYTE temp", nil +// "out:{WORD result}" -> DirOut, "result", true, "WORD result", nil +func parseParamSpec(spec string) (ParamDirection, string, bool, string, error) { + direction := DirIn // default + varName := spec + isImplicit := false + implicitDecl := "" + + // Check for direction prefix + if strings.Contains(spec, ":") { + parts := strings.SplitN(spec, ":", 2) + if len(parts) != 2 { + return 0, "", false, "", fmt.Errorf("invalid parameter spec: %q", spec) + } + + dirStr := strings.ToLower(parts[0]) + varName = parts[1] + + switch dirStr { + case "in": + direction = DirIn + case "out": + direction = DirOut + case "io": + direction = DirIn | DirOut + default: + return 0, "", false, "", fmt.Errorf("invalid parameter direction: %q", dirStr) + } + } + + // Check for implicit declaration {TYPE name} + if strings.HasPrefix(varName, "{") && strings.HasSuffix(varName, "}") { + isImplicit = true + implicitDecl = varName[1 : len(varName)-1] // strip { } + + // Extract variable name from implicit declaration + parts := strings.Fields(implicitDecl) + if len(parts) < 2 { + return 0, "", false, "", fmt.Errorf("invalid implicit declaration: %q", varName) + } + varName = parts[1] + } + + return direction, varName, isImplicit, implicitDecl, nil +} diff --git a/internal/compiler/funchandler_test.go b/internal/compiler/funchandler_test.go new file mode 100644 index 0000000..74f4a0a --- /dev/null +++ b/internal/compiler/funchandler_test.go @@ -0,0 +1,688 @@ +package compiler + +import ( + "strings" + "testing" + + "c65gm/internal/preproc" +) + +// Helper to create a test Line +func makeLine(text string) preproc.Line { + return preproc.Line{ + RawText: text, + Text: text, + Filename: "test.c65", + LineNo: 1, + Kind: preproc.Source, + PragmaSetIndex: 0, + } +} + +func TestFixIntuitiveFuncs(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"func(a,b)", "func ( a b )"}, + {"func( a, b )", "func ( a b )"}, + {"func(a,b,c)", "func ( a b c )"}, + {"CALL func()", "CALL func ( )"}, + {"func()", "func ( )"}, + {`func("hello",x)`, `func ( "hello" x )`}, + {`func("a,b",c)`, `func ( "a,b" c )`}, + {"func ( a , b )", "func ( a b )"}, + } + + for _, tt := range tests { + result := fixIntuitiveFuncs(tt.input) + if result != tt.expected { + t.Errorf("fixIntuitiveFuncs(%q) = %q, want %q", tt.input, result, tt.expected) + } + } +} + +func TestBuildComplexParams(t *testing.T) { + tests := []struct { + input []string + expected []string + wantErr bool + }{ + { + input: []string{"a", "b", "c"}, + expected: []string{"a", "b", "c"}, + wantErr: false, + }, + { + input: []string{"{BYTE", "x}"}, + expected: []string{"{BYTE x}"}, + wantErr: false, + }, + { + input: []string{"{WORD", "ptr}"}, + expected: []string{"{WORD ptr}"}, + wantErr: false, + }, + { + input: []string{"{BYTE", "a}", "{WORD", "b}"}, + expected: []string{"{BYTE a}", "{WORD b}"}, + wantErr: false, + }, + { + input: []string{"x", "{BYTE", "a}", "y"}, + expected: []string{"x", "{BYTE a}", "y"}, + wantErr: false, + }, + { + input: []string{"{BYTE", "x"}, + expected: nil, + wantErr: true, // unclosed + }, + { + input: []string{"x}"}, + expected: nil, + wantErr: true, // unmatched close + }, + { + input: []string{"{BYTE", "{WORD", "x}"}, + expected: nil, + wantErr: true, // nested open + }, + } + + for _, tt := range tests { + result, err := buildComplexParams(tt.input) + if tt.wantErr { + if err == nil { + t.Errorf("buildComplexParams(%v) expected error, got nil", tt.input) + } + continue + } + if err != nil { + t.Errorf("buildComplexParams(%v) unexpected error: %v", tt.input, err) + continue + } + if len(result) != len(tt.expected) { + t.Errorf("buildComplexParams(%v) = %v, want %v", tt.input, result, tt.expected) + continue + } + for i := range result { + if result[i] != tt.expected[i] { + t.Errorf("buildComplexParams(%v)[%d] = %q, want %q", tt.input, i, result[i], tt.expected[i]) + } + } + } +} + +func TestParseParams(t *testing.T) { + tests := []struct { + input string + expected []string + wantErr bool + }{ + {"FUNC test", []string{"FUNC", "test"}, false}, + {"FUNC test ( a b )", []string{"FUNC", "test", "(", "a", "b", ")"}, false}, + {`CALL print ( "hello world" )`, []string{"CALL", "print", "(", `"hello world"`, ")"}, false}, + {" FUNC test ", []string{"FUNC", "test"}, false}, + {`func("unterminated`, nil, true}, + } + + for _, tt := range tests { + result, err := parseParams(tt.input) + if tt.wantErr { + if err == nil { + t.Errorf("parseParams(%q) expected error, got nil", tt.input) + } + continue + } + if err != nil { + t.Errorf("parseParams(%q) unexpected error: %v", tt.input, err) + continue + } + if len(result) != len(tt.expected) { + t.Errorf("parseParams(%q) = %v, want %v", tt.input, result, tt.expected) + continue + } + for i := range result { + if result[i] != tt.expected[i] { + t.Errorf("parseParams(%q)[%d] = %q, want %q", tt.input, i, result[i], tt.expected[i]) + } + } + } +} + +func TestParseParamSpec(t *testing.T) { + tests := []struct { + input string + wantDir ParamDirection + wantName string + wantImplicit bool + wantImplDecl string + wantErr bool + }{ + {"varname", DirIn, "varname", false, "", false}, + {"in:varname", DirIn, "varname", false, "", false}, + {"out:varname", DirOut, "varname", false, "", false}, + {"io:varname", DirIn | DirOut, "varname", false, "", false}, + {"{BYTE temp}", DirIn, "temp", true, "BYTE temp", false}, + {"{WORD result}", DirIn, "result", true, "WORD result", false}, + {"out:{BYTE x}", DirOut, "x", true, "BYTE x", false}, + {"io:{WORD ptr}", DirIn | DirOut, "ptr", true, "WORD ptr", false}, + {"invalid:dir:x", 0, "", false, "", true}, + } + + for _, tt := range tests { + dir, name, isImpl, implDecl, err := parseParamSpec(tt.input) + if tt.wantErr { + if err == nil { + t.Errorf("parseParamSpec(%q) expected error, got nil", tt.input) + } + continue + } + if err != nil { + t.Errorf("parseParamSpec(%q) unexpected error: %v", tt.input, err) + continue + } + if dir != tt.wantDir { + t.Errorf("parseParamSpec(%q) direction = %v, want %v", tt.input, dir, tt.wantDir) + } + if name != tt.wantName { + t.Errorf("parseParamSpec(%q) name = %q, want %q", tt.input, name, tt.wantName) + } + if isImpl != tt.wantImplicit { + t.Errorf("parseParamSpec(%q) implicit = %v, want %v", tt.input, isImpl, tt.wantImplicit) + } + if implDecl != tt.wantImplDecl { + t.Errorf("parseParamSpec(%q) implDecl = %q, want %q", tt.input, implDecl, tt.wantImplDecl) + } + } +} + +func TestHandleFuncDecl_VoidFunction(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + asm, err := fh.HandleFuncDecl(makeLine("FUNC test_void")) + if err != nil { + t.Fatalf("HandleFuncDecl failed: %v", err) + } + + if len(asm) != 1 { + t.Fatalf("expected 1 asm line, got %d", len(asm)) + } + if asm[0] != "test_void" { + t.Errorf("expected label 'test_void', got %q", asm[0]) + } + + if !fh.FuncExists("test_void") { + t.Error("function should exist") + } +} + +func TestHandleFuncDecl_WithExistingParams(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + // Pre-declare parameters + st.AddVar("x", "test_func", KindByte, 0) + st.AddVar("y", "test_func", KindWord, 0) + + asm, err := fh.HandleFuncDecl(makeLine("FUNC test_func ( x y )")) + if err != nil { + t.Fatalf("HandleFuncDecl failed: %v", err) + } + + if len(asm) != 1 { + t.Fatalf("expected 1 asm line, got %d", len(asm)) + } + + funcDecl := fh.findFunc("test_func") + if funcDecl == nil { + t.Fatal("function not found") + } + if len(funcDecl.Params) != 2 { + t.Fatalf("expected 2 params, got %d", len(funcDecl.Params)) + } +} + +func TestHandleFuncDecl_ImplicitDeclarations(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + asm, err := fh.HandleFuncDecl(makeLine("FUNC test_impl ( {BYTE a} {WORD b} )")) + if err != nil { + t.Fatalf("HandleFuncDecl failed: %v", err) + } + + if len(asm) != 1 { + t.Fatalf("expected 1 asm line, got %d", len(asm)) + } + + // Check that variables were declared + symA := st.Lookup("a", []string{"test_impl"}) + if symA == nil { + t.Fatal("parameter 'a' not declared") + } + if !symA.IsByte() { + t.Error("parameter 'a' should be byte") + } + + symB := st.Lookup("b", []string{"test_impl"}) + if symB == nil { + t.Fatal("parameter 'b' not declared") + } + if !symB.IsWord() { + t.Error("parameter 'b' should be word") + } +} + +func TestHandleFuncDecl_WithDirections(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + _, err := fh.HandleFuncDecl(makeLine("FUNC test_dir ( in:{BYTE a} out:{BYTE b} io:{WORD c} )")) + if err != nil { + t.Fatalf("HandleFuncDecl failed: %v", err) + } + + funcDecl := fh.findFunc("test_dir") + if funcDecl == nil { + t.Fatal("function not found") + } + + if len(funcDecl.Params) != 3 { + t.Fatalf("expected 3 params, got %d", len(funcDecl.Params)) + } + + if funcDecl.Params[0].Direction != DirIn { + t.Error("param 0 should be DirIn") + } + if funcDecl.Params[1].Direction != DirOut { + t.Error("param 1 should be DirOut") + } + if funcDecl.Params[2].Direction != (DirIn | DirOut) { + t.Error("param 2 should be DirIn|DirOut") + } +} + +func TestHandleFuncDecl_Errors(t *testing.T) { + tests := []struct { + name string + line string + preDecl func(*SymbolTable) + wantErr string + }{ + { + name: "redeclaration", + line: "FUNC duplicate ( {BYTE x} )", + preDecl: func(st *SymbolTable) {}, + wantErr: "already declared", + }, + { + name: "missing param", + line: "FUNC test ( missing )", + wantErr: "not declared", + }, + { + name: "const param", + line: "FUNC test ( constval )", + preDecl: func(st *SymbolTable) { + st.AddConst("constval", "test", KindByte, 42) + }, + wantErr: "cannot be a constant", + }, + { + name: "invalid implicit", + line: "FUNC test ( {INVALID x} )", + wantErr: "must be BYTE or WORD", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + if tt.preDecl != nil { + tt.preDecl(st) + } + + // Special case for redeclaration test + if tt.name == "redeclaration" { + fh.HandleFuncDecl(makeLine("FUNC duplicate ( {BYTE x} )")) + } + + _, err := fh.HandleFuncDecl(makeLine(tt.line)) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error %q does not contain %q", err.Error(), tt.wantErr) + } + }) + } +} + +func TestHandleFuncCall_VarArgs(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + // Declare function with params + st.AddVar("param_a", "test_func", KindByte, 0) + st.AddVar("param_b", "test_func", KindWord, 0) + fh.HandleFuncDecl(makeLine("FUNC test_func ( param_a param_b )")) + + // Declare caller variables + st.AddVar("var_a", "", KindByte, 0) + st.AddVar("var_b", "", KindWord, 0) + + asm, err := fh.HandleFuncCall(makeLine("CALL test_func ( var_a var_b )")) + if err != nil { + t.Fatalf("HandleFuncCall failed: %v", err) + } + + // Check generated assembly + expectedLines := []string{ + " lda var_a", + " sta test_func_param_a", + " lda var_b", + " sta test_func_param_b", + " lda var_b+1", + " sta test_func_param_b+1", + " jsr test_func", + } + + if len(asm) != len(expectedLines) { + t.Fatalf("expected %d asm lines, got %d", len(expectedLines), len(asm)) + } + + for i, expected := range expectedLines { + if asm[i] != expected { + t.Errorf("asm[%d] = %q, want %q", i, asm[i], expected) + } + } +} + +func TestHandleFuncCall_OutParams(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + // Declare function with out param + st.AddVar("result", "get_result", KindByte, 0) + fh.HandleFuncDecl(makeLine("FUNC get_result ( out:result )")) + + // Declare caller variable + st.AddVar("output", "", KindByte, 0) + + asm, err := fh.HandleFuncCall(makeLine("CALL get_result ( output )")) + if err != nil { + t.Fatalf("HandleFuncCall failed: %v", err) + } + + // Should have JSR and OUT assignment + found_jsr := false + found_out := false + for _, line := range asm { + if strings.Contains(line, "jsr get_result") { + found_jsr = true + } + if strings.Contains(line, "lda get_result_result") { + found_out = true + } + } + + if !found_jsr { + t.Error("missing jsr instruction") + } + if !found_out { + t.Error("missing out assignment") + } +} + +func TestHandleFuncCall_ConstArgs(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + // Declare function + st.AddVar("x", "test_const", KindByte, 0) + st.AddVar("y", "test_const", KindWord, 0) + fh.HandleFuncDecl(makeLine("FUNC test_const ( x y )")) + + asm, err := fh.HandleFuncCall(makeLine("CALL test_const ( 42 $1234 )")) + if err != nil { + t.Fatalf("HandleFuncCall failed: %v", err) + } + + // Check for immediate loads + foundByte := false + foundWord := false + for _, line := range asm { + if strings.Contains(line, "lda #42") { + foundByte = true + } + if strings.Contains(line, "lda #18") { // 0x12 + foundWord = true + } + } + + if !foundByte { + t.Error("missing byte constant load") + } + if !foundWord { + t.Error("missing word constant load") + } +} + +func TestHandleFuncCall_LabelArg(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + // Declare function + st.AddVar("ptr", "test_label", KindWord, 0) + fh.HandleFuncDecl(makeLine("FUNC test_label ( ptr )")) + + asm, err := fh.HandleFuncCall(makeLine("CALL test_label ( @my_label )")) + if err != nil { + t.Fatalf("HandleFuncCall failed: %v", err) + } + + // Check for label reference + foundLow := false + foundHigh := false + for _, line := range asm { + if strings.Contains(line, "#my_label") { + foundHigh = true + } + } + + if !foundLow || !foundHigh { + t.Error("missing label reference code") + } +} + +func TestHandleFuncCall_StringArg(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + // Declare function + st.AddVar("str_ptr", "print", KindWord, 0) + fh.HandleFuncDecl(makeLine("FUNC print ( str_ptr )")) + + asm, err := fh.HandleFuncCall(makeLine(`CALL print ( "hello" )`)) + if err != nil { + t.Fatalf("HandleFuncCall failed: %v", err) + } + + // Check that label was generated + if ls.Size() != 1 { + t.Errorf("expected 1 label generated, got %d", ls.Size()) + } + + // Check for label reference in asm + foundLabel := false + for _, line := range asm { + if strings.Contains(line, "#L1") { + foundLabel = true + break + } + } + + if !foundLabel { + t.Error("missing string label reference") + } +} + +func TestHandleFuncCall_Errors(t *testing.T) { + tests := []struct { + name string + setup func(*FunctionHandler, *SymbolTable) + line string + wantErr string + }{ + { + name: "function not declared", + setup: func(fh *FunctionHandler, st *SymbolTable) {}, + line: "CALL undefined ( )", + wantErr: "not declared", + }, + { + name: "wrong arg count", + setup: func(fh *FunctionHandler, st *SymbolTable) { + st.AddVar("x", "test", KindByte, 0) + fh.HandleFuncDecl(makeLine("FUNC test ( x )")) + }, + line: "CALL test ( 1 2 )", + wantErr: "expected 1 arguments, got 2", + }, + { + name: "type mismatch", + setup: func(fh *FunctionHandler, st *SymbolTable) { + st.AddVar("param", "test", KindByte, 0) + fh.HandleFuncDecl(makeLine("FUNC test ( param )")) + st.AddVar("wvar", "", KindWord, 0) + }, + line: "CALL test ( wvar )", + wantErr: "type mismatch", + }, + { + name: "const to out param", + setup: func(fh *FunctionHandler, st *SymbolTable) { + st.AddVar("result", "test", KindByte, 0) + fh.HandleFuncDecl(makeLine("FUNC test ( out:result )")) + }, + line: "CALL test ( 42 )", + wantErr: "out/io parameter", + }, + { + name: "label to byte param", + setup: func(fh *FunctionHandler, st *SymbolTable) { + st.AddVar("x", "test", KindByte, 0) + fh.HandleFuncDecl(makeLine("FUNC test ( x )")) + }, + line: "CALL test ( @label )", + wantErr: "byte parameter", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + tt.setup(fh, st) + + _, err := fh.HandleFuncCall(makeLine(tt.line)) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error %q does not contain %q", err.Error(), tt.wantErr) + } + }) + } +} + +func TestEndFunction(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + // Declare function (pushes to stack) + fh.HandleFuncDecl(makeLine("FUNC test ( {BYTE x} )")) + + if fh.CurrentFunction() != "test" { + t.Errorf("current function = %q, want 'test'", fh.CurrentFunction()) + } + + // End function + fh.EndFunction() + + if fh.CurrentFunction() != "" { + t.Errorf("current function = %q, want ''", fh.CurrentFunction()) + } +} + +func TestCurrentFunction(t *testing.T) { + st := NewSymbolTable() + ls := NewLabelStack("L") + csh := NewConstantStringHandler() + pragma := preproc.NewPragma() + fh := NewFunctionHandler(st, ls, csh, pragma) + + if fh.CurrentFunction() != "" { + t.Error("expected empty current function initially") + } + + fh.HandleFuncDecl(makeLine("FUNC func1 ( {BYTE x} )")) + if fh.CurrentFunction() != "func1" { + t.Errorf("expected 'func1', got %q", fh.CurrentFunction()) + } + + fh.HandleFuncDecl(makeLine("FUNC func2 ( {BYTE y} )")) + if fh.CurrentFunction() != "func2" { + t.Errorf("expected 'func2', got %q", fh.CurrentFunction()) + } + + fh.EndFunction() + if fh.CurrentFunction() != "" { + t.Errorf("expected '', got %q", fh.CurrentFunction()) + } +}