Added funchandler.go + tests

This commit is contained in:
Mattias Hansson 2025-10-27 20:56:43 +01:00
parent 495aaeb6c0
commit 558dc8657c
2 changed files with 1345 additions and 0 deletions

View file

@ -0,0 +1,657 @@
package compiler
import (
"fmt"
"strings"
"c65gm/internal/preproc"
)
// ParamDirection represents parameter passing direction
type ParamDirection uint8
const (
DirIn ParamDirection = 1 << iota
DirOut
)
func (d ParamDirection) Has(flag ParamDirection) bool {
return d&flag != 0
}
// FuncParam represents a function parameter
type FuncParam struct {
Symbol *Symbol
Direction ParamDirection
}
// FuncDecl represents a function declaration
type FuncDecl struct {
Name string
Params []*FuncParam
}
// FunctionHandler manages function declarations and calls
type FunctionHandler struct {
functions []*FuncDecl
currentFuncs []string // stack of current function names (for nested scope)
symTable *SymbolTable
labelStack *LabelStack
constStrHandler *ConstantStringHandler
pragma *preproc.Pragma
}
// NewFunctionHandler creates a new function handler
func NewFunctionHandler(st *SymbolTable, ls *LabelStack, csh *ConstantStringHandler, pragma *preproc.Pragma) *FunctionHandler {
return &FunctionHandler{
functions: make([]*FuncDecl, 0),
currentFuncs: make([]string, 0),
symTable: st,
labelStack: ls,
constStrHandler: csh,
pragma: pragma,
}
}
// HandleFuncDecl parses and processes a FUNC declaration
// Syntax: FUNC name ( param1 param2 ... )
// Or: FUNC name (void function)
func (fh *FunctionHandler) HandleFuncDecl(line preproc.Line) ([]string, error) {
// Normalize parentheses and commas
text := fixIntuitiveFuncs(line.Text)
params, err := parseParams(text)
if err != nil {
return nil, fmt.Errorf("%s:%d: %w", line.Filename, line.LineNo, err)
}
if len(params) < 2 {
return nil, fmt.Errorf("%s:%d: FUNC: expected at least function name", line.Filename, line.LineNo)
}
if strings.ToUpper(params[0]) != "FUNC" {
return nil, fmt.Errorf("%s:%d: not a FUNC declaration", line.Filename, line.LineNo)
}
funcName := params[1]
// Check for redeclaration
if fh.FuncExists(funcName) {
return nil, fmt.Errorf("%s:%d: function %q already declared", line.Filename, line.LineNo, funcName)
}
// Push function name to current function stack early
// (so param declarations get correct scope)
fh.currentFuncs = append(fh.currentFuncs, funcName)
// Parse parameters
var funcParams []*FuncParam
if len(params) == 2 {
// Void function: FUNC name
// No parameters
} else if len(params) >= 5 {
// FUNC name ( param1 param2 )
if params[2] != "(" || params[len(params)-1] != ")" {
fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1]
return nil, fmt.Errorf("%s:%d: FUNC: expected parentheses around parameters", line.Filename, line.LineNo)
}
// Extract params between ( and ) - need to handle {BYTE x} specially
rawParamTokens := params[3 : len(params)-1]
paramSpecs, err := buildComplexParams(rawParamTokens)
if err != nil {
fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1]
return nil, fmt.Errorf("%s:%d: FUNC %s: %w", line.Filename, line.LineNo, funcName, err)
}
for _, spec := range paramSpecs {
direction, varName, isImplicit, implicitDecl, err := parseParamSpec(spec)
if err != nil {
fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1]
return nil, fmt.Errorf("%s:%d: FUNC %s: %w", line.Filename, line.LineNo, funcName, err)
}
if isImplicit {
// Parse and add implicit variable declaration
// Format: {BYTE varname} or {WORD varname}
if err := fh.parseImplicitDecl(implicitDecl, funcName); err != nil {
fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1]
return nil, fmt.Errorf("%s:%d: FUNC %s: implicit declaration: %w", line.Filename, line.LineNo, funcName, err)
}
}
// Look up variable in symbol table
sym := fh.symTable.Lookup(varName, []string{funcName})
if sym == nil {
fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1]
return nil, fmt.Errorf("%s:%d: FUNC %s: parameter %q not declared", line.Filename, line.LineNo, funcName, varName)
}
if sym.IsConst() {
fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1]
return nil, fmt.Errorf("%s:%d: FUNC %s: parameter %q cannot be a constant", line.Filename, line.LineNo, funcName, varName)
}
funcParams = append(funcParams, &FuncParam{
Symbol: sym,
Direction: direction,
})
}
} else {
fh.currentFuncs = fh.currentFuncs[:len(fh.currentFuncs)-1]
return nil, fmt.Errorf("%s:%d: FUNC: invalid syntax", line.Filename, line.LineNo)
}
// Store function declaration
fh.functions = append(fh.functions, &FuncDecl{
Name: funcName,
Params: funcParams,
})
// Generate assembler label
return []string{funcName}, nil
}
// buildComplexParams handles parameter lists that may contain {BYTE x} style declarations
// Tokens like {BYTE x} are spread across multiple tokens and need to be reassembled
func buildComplexParams(tokens []string) ([]string, error) {
var result []string
var current string
inBraces := false
for _, token := range tokens {
hasStart := strings.Contains(token, "{")
hasEnd := strings.Contains(token, "}")
if !inBraces {
// Not currently in braces
if hasEnd && !hasStart {
return nil, fmt.Errorf("unexpected } without matching {")
}
if hasStart {
// Starting a brace block
inBraces = true
current = token
// Check if it also ends on same token
if hasEnd {
result = append(result, current)
current = ""
inBraces = false
}
} else {
// Regular param
result = append(result, token)
}
} else {
// Currently accumulating in braces
if hasStart {
return nil, fmt.Errorf("unexpected { while already in braces")
}
current += " " + token
if hasEnd {
result = append(result, current)
current = ""
inBraces = false
}
}
}
if inBraces {
return nil, fmt.Errorf("unclosed { in parameter list")
}
return result, nil
}
// HandleFuncCall generates code for a function call
// Syntax: CALL funcname ( arg1 arg2 ... )
// Or: funcname ( arg1 arg2 ... )
func (fh *FunctionHandler) HandleFuncCall(line preproc.Line) ([]string, error) {
// Normalize parentheses and commas
text := fixIntuitiveFuncs(line.Text)
params, err := parseParams(text)
if err != nil {
return nil, fmt.Errorf("%s:%d: %w", line.Filename, line.LineNo, err)
}
if len(params) < 1 {
return nil, fmt.Errorf("%s:%d: CALL: empty line", line.Filename, line.LineNo)
}
// Check if starts with CALL keyword
startsWithCall := strings.ToUpper(params[0]) == "CALL"
funcNameIdx := 0
if startsWithCall {
if len(params) < 2 {
return nil, fmt.Errorf("%s:%d: CALL: expected function name", line.Filename, line.LineNo)
}
funcNameIdx = 1
}
funcName := params[funcNameIdx]
// Check if function exists
funcDecl := fh.findFunc(funcName)
if funcDecl == nil {
return nil, fmt.Errorf("%s:%d: function %q not declared", line.Filename, line.LineNo, funcName)
}
// Parse call arguments
var callArgs []string
if len(params) == funcNameIdx+1 {
// No arguments: funcname or CALL funcname
callArgs = []string{}
} else if len(params) >= funcNameIdx+4 {
// funcname ( arg1 arg2 ) or CALL funcname ( arg1 arg2 )
if params[funcNameIdx+1] != "(" || params[len(params)-1] != ")" {
return nil, fmt.Errorf("%s:%d: CALL %s: expected parentheses around arguments", line.Filename, line.LineNo, funcName)
}
callArgs = params[funcNameIdx+2 : len(params)-1]
} else {
return nil, fmt.Errorf("%s:%d: CALL %s: invalid syntax", line.Filename, line.LineNo, funcName)
}
// Check argument count matches
if len(callArgs) != len(funcDecl.Params) {
return nil, fmt.Errorf("%s:%d: CALL %s: expected %d arguments, got %d",
line.Filename, line.LineNo, funcName, len(funcDecl.Params), len(callArgs))
}
// Get pragma set for this line
pragmaSet := fh.pragma.GetPragmaSetByIndex(line.PragmaSetIndex)
var asmLines []string
var inAssigns []string
var outAssigns []string
// Process each argument
for i, arg := range callArgs {
param := funcDecl.Params[i]
// Handle different argument types
if strings.HasPrefix(arg, "@") {
// Label reference: @labelname
if err := fh.processLabelArg(arg, param, funcName, line, &inAssigns); err != nil {
return nil, err
}
} else if strings.HasPrefix(arg, "\"") && strings.HasSuffix(arg, "\"") {
// String constant
if err := fh.processStringArg(arg, param, funcName, line, pragmaSet, &inAssigns); err != nil {
return nil, err
}
} else if sym := fh.symTable.Lookup(arg, fh.currentFuncs); sym != nil {
// Variable reference
if err := fh.processVarArg(sym, param, funcName, line, &inAssigns, &outAssigns); err != nil {
return nil, err
}
} else {
// Numeric constant
if err := fh.processConstArg(arg, param, funcName, line, &inAssigns); err != nil {
return nil, err
}
}
}
// Generate final assembly
asmLines = append(asmLines, inAssigns...)
asmLines = append(asmLines, fmt.Sprintf(" jsr %s", funcName))
asmLines = append(asmLines, outAssigns...)
return asmLines, nil
}
// processLabelArg handles @label arguments
func (fh *FunctionHandler) processLabelArg(arg string, param *FuncParam, funcName string, line preproc.Line, inAssigns *[]string) error {
labelName := arg[1:] // strip @
if param.Symbol.IsByte() {
return fmt.Errorf("%s:%d: CALL %s: cannot pass label to byte parameter", line.Filename, line.LineNo, funcName)
}
if param.Direction.Has(DirOut) {
return fmt.Errorf("%s:%d: CALL %s: cannot pass label to out/io parameter", line.Filename, line.LineNo, funcName)
}
*inAssigns = append(*inAssigns,
fmt.Sprintf(" lda #<%s", labelName),
fmt.Sprintf(" sta %s", param.Symbol.FullName()),
fmt.Sprintf(" lda #>%s", labelName),
fmt.Sprintf(" sta %s+1", param.Symbol.FullName()),
)
return nil
}
// processStringArg handles "string" arguments
func (fh *FunctionHandler) processStringArg(arg string, param *FuncParam, funcName string, line preproc.Line, pragmaSet preproc.PragmaSet, inAssigns *[]string) error {
if param.Symbol.IsByte() {
return fmt.Errorf("%s:%d: CALL %s: cannot pass string to byte parameter", line.Filename, line.LineNo, funcName)
}
if param.Direction.Has(DirOut) {
return fmt.Errorf("%s:%d: CALL %s: cannot pass string to out/io parameter", line.Filename, line.LineNo, funcName)
}
// Generate label for string constant
labelName := fh.labelStack.Push()
fh.constStrHandler.AddConstStr(labelName, arg, true, pragmaSet)
*inAssigns = append(*inAssigns,
fmt.Sprintf(" lda #<%s", labelName),
fmt.Sprintf(" sta %s", param.Symbol.FullName()),
fmt.Sprintf(" lda #>%s", labelName),
fmt.Sprintf(" sta %s+1", param.Symbol.FullName()),
)
return nil
}
// processVarArg handles variable arguments
func (fh *FunctionHandler) processVarArg(sym *Symbol, param *FuncParam, funcName string, line preproc.Line, inAssigns, outAssigns *[]string) error {
// Check type compatibility
if (sym.IsByte() && param.Symbol.IsWord()) || (sym.IsWord() && param.Symbol.IsByte()) {
return fmt.Errorf("%s:%d: CALL %s: type mismatch for parameter %s", line.Filename, line.LineNo, funcName, param.Symbol.Name)
}
if sym.IsConst() {
return fmt.Errorf("%s:%d: CALL %s: cannot pass constant to function", line.Filename, line.LineNo, funcName)
}
// Generate IN assignments
if param.Direction.Has(DirIn) {
*inAssigns = append(*inAssigns,
fmt.Sprintf(" lda %s", sym.FullName()),
fmt.Sprintf(" sta %s", param.Symbol.FullName()),
)
if sym.IsWord() {
*inAssigns = append(*inAssigns,
fmt.Sprintf(" lda %s+1", sym.FullName()),
fmt.Sprintf(" sta %s+1", param.Symbol.FullName()),
)
}
}
// Generate OUT assignments
if param.Direction.Has(DirOut) {
*outAssigns = append(*outAssigns,
fmt.Sprintf(" lda %s", param.Symbol.FullName()),
fmt.Sprintf(" sta %s", sym.FullName()),
)
if sym.IsWord() {
*outAssigns = append(*outAssigns,
fmt.Sprintf(" lda %s+1", param.Symbol.FullName()),
fmt.Sprintf(" sta %s+1", sym.FullName()),
)
}
}
return nil
}
// processConstArg handles numeric constant arguments
func (fh *FunctionHandler) processConstArg(arg string, param *FuncParam, funcName string, line preproc.Line, inAssigns *[]string) error {
if param.Direction.Has(DirOut) {
return fmt.Errorf("%s:%d: CALL %s: cannot pass constant to out/io parameter", line.Filename, line.LineNo, funcName)
}
// Parse numeric value (supports decimal and hex with $ prefix)
var value int64
var err error
if strings.HasPrefix(arg, "$") {
_, err = fmt.Sscanf(arg[1:], "%x", &value)
} else {
_, err = fmt.Sscanf(arg, "%d", &value)
}
if err != nil {
return fmt.Errorf("%s:%d: CALL %s: invalid numeric constant %q", line.Filename, line.LineNo, funcName, arg)
}
if param.Symbol.IsByte() && (value < 0 || value > 255) {
return fmt.Errorf("%s:%d: CALL %s: constant %d out of byte range", line.Filename, line.LineNo, funcName, value)
}
if value < 0 || value > 65535 {
return fmt.Errorf("%s:%d: CALL %s: constant %d out of word range", line.Filename, line.LineNo, funcName, value)
}
lowByte := uint8(value & 0xFF)
highByte := uint8((value >> 8) & 0xFF)
*inAssigns = append(*inAssigns,
fmt.Sprintf(" lda #%d", lowByte),
fmt.Sprintf(" sta %s", param.Symbol.FullName()),
)
if param.Symbol.IsWord() {
// Optimize: only reload A if high byte differs
if highByte != lowByte {
*inAssigns = append(*inAssigns, fmt.Sprintf(" lda #%d", highByte))
}
*inAssigns = append(*inAssigns, fmt.Sprintf(" sta %s+1", param.Symbol.FullName()))
}
return nil
}
// parseImplicitDecl parses {BYTE varname} or {WORD varname} and adds to symbol table
func (fh *FunctionHandler) parseImplicitDecl(decl string, funcName string) error {
parts := strings.Fields(decl)
if len(parts) != 2 {
return fmt.Errorf("implicit declaration must be 'BYTE name' or 'WORD name', got: %q", decl)
}
typeStr := strings.ToUpper(parts[0])
varName := parts[1]
var kind VarKind
switch typeStr {
case "BYTE":
kind = KindByte
case "WORD":
kind = KindWord
default:
return fmt.Errorf("implicit declaration type must be BYTE or WORD, got: %s", typeStr)
}
// Add variable to symbol table with function scope
return fh.symTable.AddVar(varName, funcName, kind, 0)
}
// EndFunction pops all functions from the stack (called by FEND)
func (fh *FunctionHandler) EndFunction() {
fh.currentFuncs = fh.currentFuncs[:0]
}
// FuncExists checks if a function is declared
func (fh *FunctionHandler) FuncExists(name string) bool {
return fh.findFunc(name) != nil
}
// CurrentFunction returns the current function name (or empty if global scope)
func (fh *FunctionHandler) CurrentFunction() string {
if len(fh.currentFuncs) == 0 {
return ""
}
return fh.currentFuncs[len(fh.currentFuncs)-1]
}
// findFunc finds a function declaration by name
func (fh *FunctionHandler) findFunc(name string) *FuncDecl {
for _, f := range fh.functions {
if f.Name == name {
return f
}
}
return nil
}
// fixIntuitiveFuncs normalizes function syntax
// Separates '(' and ')' into own tokens, removes commas
// Example: "func(a,b)" -> "func ( a b )"
func fixIntuitiveFuncs(s string) string {
var result strings.Builder
inString := false
for i := 0; i < len(s); i++ {
ch := s[i]
if ch == '"' {
inString = !inString
result.WriteByte(ch)
continue
}
if !inString {
if ch == '(' || ch == ')' {
result.WriteByte(' ')
result.WriteByte(ch)
result.WriteByte(' ')
} else if ch == ',' {
result.WriteByte(' ')
} else {
result.WriteByte(ch)
}
} else {
result.WriteByte(ch)
}
}
return normalizeSpaces(result.String())
}
// normalizeSpaces reduces multiple spaces to single space
func normalizeSpaces(s string) string {
s = strings.TrimSpace(s)
var result strings.Builder
inString := false
lastWasSpace := false
for i := 0; i < len(s); i++ {
ch := s[i]
if ch == '"' {
inString = !inString
result.WriteByte(ch)
lastWasSpace = false
continue
}
if !inString {
if ch == ' ' || ch == '\t' {
if !lastWasSpace {
result.WriteByte(' ')
lastWasSpace = true
}
} else {
result.WriteByte(ch)
lastWasSpace = false
}
} else {
result.WriteByte(ch)
lastWasSpace = false
}
}
return result.String()
}
// parseParams splits line into space-separated parameters, respecting quoted strings
func parseParams(s string) ([]string, error) {
s = strings.TrimSpace(s)
if s == "" {
return []string{}, nil
}
var params []string
var current strings.Builder
inString := false
for i := 0; i < len(s); i++ {
ch := s[i]
if ch == '"' {
inString = !inString
current.WriteByte(ch)
continue
}
if !inString && (ch == ' ' || ch == '\t') {
if current.Len() > 0 {
params = append(params, current.String())
current.Reset()
}
} else {
current.WriteByte(ch)
}
}
if current.Len() > 0 {
params = append(params, current.String())
}
if inString {
return nil, fmt.Errorf("unterminated string in line")
}
return params, nil
}
// parseParamSpec parses a parameter specification
// Returns: direction, varName, isImplicit, implicitDecl, error
// Examples:
//
// "varname" -> DirIn, "varname", false, "", nil
// "in:varname" -> DirIn, "varname", false, "", nil
// "out:varname" -> DirOut, "varname", false, "", nil
// "io:varname" -> DirIn|DirOut, "varname", false, "", nil
// "{BYTE temp}" -> DirIn, "temp", true, "BYTE temp", nil
// "out:{WORD result}" -> DirOut, "result", true, "WORD result", nil
func parseParamSpec(spec string) (ParamDirection, string, bool, string, error) {
direction := DirIn // default
varName := spec
isImplicit := false
implicitDecl := ""
// Check for direction prefix
if strings.Contains(spec, ":") {
parts := strings.SplitN(spec, ":", 2)
if len(parts) != 2 {
return 0, "", false, "", fmt.Errorf("invalid parameter spec: %q", spec)
}
dirStr := strings.ToLower(parts[0])
varName = parts[1]
switch dirStr {
case "in":
direction = DirIn
case "out":
direction = DirOut
case "io":
direction = DirIn | DirOut
default:
return 0, "", false, "", fmt.Errorf("invalid parameter direction: %q", dirStr)
}
}
// Check for implicit declaration {TYPE name}
if strings.HasPrefix(varName, "{") && strings.HasSuffix(varName, "}") {
isImplicit = true
implicitDecl = varName[1 : len(varName)-1] // strip { }
// Extract variable name from implicit declaration
parts := strings.Fields(implicitDecl)
if len(parts) < 2 {
return 0, "", false, "", fmt.Errorf("invalid implicit declaration: %q", varName)
}
varName = parts[1]
}
return direction, varName, isImplicit, implicitDecl, nil
}

View file

@ -0,0 +1,688 @@
package compiler
import (
"strings"
"testing"
"c65gm/internal/preproc"
)
// Helper to create a test Line
func makeLine(text string) preproc.Line {
return preproc.Line{
RawText: text,
Text: text,
Filename: "test.c65",
LineNo: 1,
Kind: preproc.Source,
PragmaSetIndex: 0,
}
}
func TestFixIntuitiveFuncs(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"func(a,b)", "func ( a b )"},
{"func( a, b )", "func ( a b )"},
{"func(a,b,c)", "func ( a b c )"},
{"CALL func()", "CALL func ( )"},
{"func()", "func ( )"},
{`func("hello",x)`, `func ( "hello" x )`},
{`func("a,b",c)`, `func ( "a,b" c )`},
{"func ( a , b )", "func ( a b )"},
}
for _, tt := range tests {
result := fixIntuitiveFuncs(tt.input)
if result != tt.expected {
t.Errorf("fixIntuitiveFuncs(%q) = %q, want %q", tt.input, result, tt.expected)
}
}
}
func TestBuildComplexParams(t *testing.T) {
tests := []struct {
input []string
expected []string
wantErr bool
}{
{
input: []string{"a", "b", "c"},
expected: []string{"a", "b", "c"},
wantErr: false,
},
{
input: []string{"{BYTE", "x}"},
expected: []string{"{BYTE x}"},
wantErr: false,
},
{
input: []string{"{WORD", "ptr}"},
expected: []string{"{WORD ptr}"},
wantErr: false,
},
{
input: []string{"{BYTE", "a}", "{WORD", "b}"},
expected: []string{"{BYTE a}", "{WORD b}"},
wantErr: false,
},
{
input: []string{"x", "{BYTE", "a}", "y"},
expected: []string{"x", "{BYTE a}", "y"},
wantErr: false,
},
{
input: []string{"{BYTE", "x"},
expected: nil,
wantErr: true, // unclosed
},
{
input: []string{"x}"},
expected: nil,
wantErr: true, // unmatched close
},
{
input: []string{"{BYTE", "{WORD", "x}"},
expected: nil,
wantErr: true, // nested open
},
}
for _, tt := range tests {
result, err := buildComplexParams(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("buildComplexParams(%v) expected error, got nil", tt.input)
}
continue
}
if err != nil {
t.Errorf("buildComplexParams(%v) unexpected error: %v", tt.input, err)
continue
}
if len(result) != len(tt.expected) {
t.Errorf("buildComplexParams(%v) = %v, want %v", tt.input, result, tt.expected)
continue
}
for i := range result {
if result[i] != tt.expected[i] {
t.Errorf("buildComplexParams(%v)[%d] = %q, want %q", tt.input, i, result[i], tt.expected[i])
}
}
}
}
func TestParseParams(t *testing.T) {
tests := []struct {
input string
expected []string
wantErr bool
}{
{"FUNC test", []string{"FUNC", "test"}, false},
{"FUNC test ( a b )", []string{"FUNC", "test", "(", "a", "b", ")"}, false},
{`CALL print ( "hello world" )`, []string{"CALL", "print", "(", `"hello world"`, ")"}, false},
{" FUNC test ", []string{"FUNC", "test"}, false},
{`func("unterminated`, nil, true},
}
for _, tt := range tests {
result, err := parseParams(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("parseParams(%q) expected error, got nil", tt.input)
}
continue
}
if err != nil {
t.Errorf("parseParams(%q) unexpected error: %v", tt.input, err)
continue
}
if len(result) != len(tt.expected) {
t.Errorf("parseParams(%q) = %v, want %v", tt.input, result, tt.expected)
continue
}
for i := range result {
if result[i] != tt.expected[i] {
t.Errorf("parseParams(%q)[%d] = %q, want %q", tt.input, i, result[i], tt.expected[i])
}
}
}
}
func TestParseParamSpec(t *testing.T) {
tests := []struct {
input string
wantDir ParamDirection
wantName string
wantImplicit bool
wantImplDecl string
wantErr bool
}{
{"varname", DirIn, "varname", false, "", false},
{"in:varname", DirIn, "varname", false, "", false},
{"out:varname", DirOut, "varname", false, "", false},
{"io:varname", DirIn | DirOut, "varname", false, "", false},
{"{BYTE temp}", DirIn, "temp", true, "BYTE temp", false},
{"{WORD result}", DirIn, "result", true, "WORD result", false},
{"out:{BYTE x}", DirOut, "x", true, "BYTE x", false},
{"io:{WORD ptr}", DirIn | DirOut, "ptr", true, "WORD ptr", false},
{"invalid:dir:x", 0, "", false, "", true},
}
for _, tt := range tests {
dir, name, isImpl, implDecl, err := parseParamSpec(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("parseParamSpec(%q) expected error, got nil", tt.input)
}
continue
}
if err != nil {
t.Errorf("parseParamSpec(%q) unexpected error: %v", tt.input, err)
continue
}
if dir != tt.wantDir {
t.Errorf("parseParamSpec(%q) direction = %v, want %v", tt.input, dir, tt.wantDir)
}
if name != tt.wantName {
t.Errorf("parseParamSpec(%q) name = %q, want %q", tt.input, name, tt.wantName)
}
if isImpl != tt.wantImplicit {
t.Errorf("parseParamSpec(%q) implicit = %v, want %v", tt.input, isImpl, tt.wantImplicit)
}
if implDecl != tt.wantImplDecl {
t.Errorf("parseParamSpec(%q) implDecl = %q, want %q", tt.input, implDecl, tt.wantImplDecl)
}
}
}
func TestHandleFuncDecl_VoidFunction(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
asm, err := fh.HandleFuncDecl(makeLine("FUNC test_void"))
if err != nil {
t.Fatalf("HandleFuncDecl failed: %v", err)
}
if len(asm) != 1 {
t.Fatalf("expected 1 asm line, got %d", len(asm))
}
if asm[0] != "test_void" {
t.Errorf("expected label 'test_void', got %q", asm[0])
}
if !fh.FuncExists("test_void") {
t.Error("function should exist")
}
}
func TestHandleFuncDecl_WithExistingParams(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
// Pre-declare parameters
st.AddVar("x", "test_func", KindByte, 0)
st.AddVar("y", "test_func", KindWord, 0)
asm, err := fh.HandleFuncDecl(makeLine("FUNC test_func ( x y )"))
if err != nil {
t.Fatalf("HandleFuncDecl failed: %v", err)
}
if len(asm) != 1 {
t.Fatalf("expected 1 asm line, got %d", len(asm))
}
funcDecl := fh.findFunc("test_func")
if funcDecl == nil {
t.Fatal("function not found")
}
if len(funcDecl.Params) != 2 {
t.Fatalf("expected 2 params, got %d", len(funcDecl.Params))
}
}
func TestHandleFuncDecl_ImplicitDeclarations(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
asm, err := fh.HandleFuncDecl(makeLine("FUNC test_impl ( {BYTE a} {WORD b} )"))
if err != nil {
t.Fatalf("HandleFuncDecl failed: %v", err)
}
if len(asm) != 1 {
t.Fatalf("expected 1 asm line, got %d", len(asm))
}
// Check that variables were declared
symA := st.Lookup("a", []string{"test_impl"})
if symA == nil {
t.Fatal("parameter 'a' not declared")
}
if !symA.IsByte() {
t.Error("parameter 'a' should be byte")
}
symB := st.Lookup("b", []string{"test_impl"})
if symB == nil {
t.Fatal("parameter 'b' not declared")
}
if !symB.IsWord() {
t.Error("parameter 'b' should be word")
}
}
func TestHandleFuncDecl_WithDirections(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
_, err := fh.HandleFuncDecl(makeLine("FUNC test_dir ( in:{BYTE a} out:{BYTE b} io:{WORD c} )"))
if err != nil {
t.Fatalf("HandleFuncDecl failed: %v", err)
}
funcDecl := fh.findFunc("test_dir")
if funcDecl == nil {
t.Fatal("function not found")
}
if len(funcDecl.Params) != 3 {
t.Fatalf("expected 3 params, got %d", len(funcDecl.Params))
}
if funcDecl.Params[0].Direction != DirIn {
t.Error("param 0 should be DirIn")
}
if funcDecl.Params[1].Direction != DirOut {
t.Error("param 1 should be DirOut")
}
if funcDecl.Params[2].Direction != (DirIn | DirOut) {
t.Error("param 2 should be DirIn|DirOut")
}
}
func TestHandleFuncDecl_Errors(t *testing.T) {
tests := []struct {
name string
line string
preDecl func(*SymbolTable)
wantErr string
}{
{
name: "redeclaration",
line: "FUNC duplicate ( {BYTE x} )",
preDecl: func(st *SymbolTable) {},
wantErr: "already declared",
},
{
name: "missing param",
line: "FUNC test ( missing )",
wantErr: "not declared",
},
{
name: "const param",
line: "FUNC test ( constval )",
preDecl: func(st *SymbolTable) {
st.AddConst("constval", "test", KindByte, 42)
},
wantErr: "cannot be a constant",
},
{
name: "invalid implicit",
line: "FUNC test ( {INVALID x} )",
wantErr: "must be BYTE or WORD",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
if tt.preDecl != nil {
tt.preDecl(st)
}
// Special case for redeclaration test
if tt.name == "redeclaration" {
fh.HandleFuncDecl(makeLine("FUNC duplicate ( {BYTE x} )"))
}
_, err := fh.HandleFuncDecl(makeLine(tt.line))
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("error %q does not contain %q", err.Error(), tt.wantErr)
}
})
}
}
func TestHandleFuncCall_VarArgs(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
// Declare function with params
st.AddVar("param_a", "test_func", KindByte, 0)
st.AddVar("param_b", "test_func", KindWord, 0)
fh.HandleFuncDecl(makeLine("FUNC test_func ( param_a param_b )"))
// Declare caller variables
st.AddVar("var_a", "", KindByte, 0)
st.AddVar("var_b", "", KindWord, 0)
asm, err := fh.HandleFuncCall(makeLine("CALL test_func ( var_a var_b )"))
if err != nil {
t.Fatalf("HandleFuncCall failed: %v", err)
}
// Check generated assembly
expectedLines := []string{
" lda var_a",
" sta test_func_param_a",
" lda var_b",
" sta test_func_param_b",
" lda var_b+1",
" sta test_func_param_b+1",
" jsr test_func",
}
if len(asm) != len(expectedLines) {
t.Fatalf("expected %d asm lines, got %d", len(expectedLines), len(asm))
}
for i, expected := range expectedLines {
if asm[i] != expected {
t.Errorf("asm[%d] = %q, want %q", i, asm[i], expected)
}
}
}
func TestHandleFuncCall_OutParams(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
// Declare function with out param
st.AddVar("result", "get_result", KindByte, 0)
fh.HandleFuncDecl(makeLine("FUNC get_result ( out:result )"))
// Declare caller variable
st.AddVar("output", "", KindByte, 0)
asm, err := fh.HandleFuncCall(makeLine("CALL get_result ( output )"))
if err != nil {
t.Fatalf("HandleFuncCall failed: %v", err)
}
// Should have JSR and OUT assignment
found_jsr := false
found_out := false
for _, line := range asm {
if strings.Contains(line, "jsr get_result") {
found_jsr = true
}
if strings.Contains(line, "lda get_result_result") {
found_out = true
}
}
if !found_jsr {
t.Error("missing jsr instruction")
}
if !found_out {
t.Error("missing out assignment")
}
}
func TestHandleFuncCall_ConstArgs(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
// Declare function
st.AddVar("x", "test_const", KindByte, 0)
st.AddVar("y", "test_const", KindWord, 0)
fh.HandleFuncDecl(makeLine("FUNC test_const ( x y )"))
asm, err := fh.HandleFuncCall(makeLine("CALL test_const ( 42 $1234 )"))
if err != nil {
t.Fatalf("HandleFuncCall failed: %v", err)
}
// Check for immediate loads
foundByte := false
foundWord := false
for _, line := range asm {
if strings.Contains(line, "lda #42") {
foundByte = true
}
if strings.Contains(line, "lda #18") { // 0x12
foundWord = true
}
}
if !foundByte {
t.Error("missing byte constant load")
}
if !foundWord {
t.Error("missing word constant load")
}
}
func TestHandleFuncCall_LabelArg(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
// Declare function
st.AddVar("ptr", "test_label", KindWord, 0)
fh.HandleFuncDecl(makeLine("FUNC test_label ( ptr )"))
asm, err := fh.HandleFuncCall(makeLine("CALL test_label ( @my_label )"))
if err != nil {
t.Fatalf("HandleFuncCall failed: %v", err)
}
// Check for label reference
foundLow := false
foundHigh := false
for _, line := range asm {
if strings.Contains(line, "#<my_label") {
foundLow = true
}
if strings.Contains(line, "#>my_label") {
foundHigh = true
}
}
if !foundLow || !foundHigh {
t.Error("missing label reference code")
}
}
func TestHandleFuncCall_StringArg(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
// Declare function
st.AddVar("str_ptr", "print", KindWord, 0)
fh.HandleFuncDecl(makeLine("FUNC print ( str_ptr )"))
asm, err := fh.HandleFuncCall(makeLine(`CALL print ( "hello" )`))
if err != nil {
t.Fatalf("HandleFuncCall failed: %v", err)
}
// Check that label was generated
if ls.Size() != 1 {
t.Errorf("expected 1 label generated, got %d", ls.Size())
}
// Check for label reference in asm
foundLabel := false
for _, line := range asm {
if strings.Contains(line, "#<L1") || strings.Contains(line, "#>L1") {
foundLabel = true
break
}
}
if !foundLabel {
t.Error("missing string label reference")
}
}
func TestHandleFuncCall_Errors(t *testing.T) {
tests := []struct {
name string
setup func(*FunctionHandler, *SymbolTable)
line string
wantErr string
}{
{
name: "function not declared",
setup: func(fh *FunctionHandler, st *SymbolTable) {},
line: "CALL undefined ( )",
wantErr: "not declared",
},
{
name: "wrong arg count",
setup: func(fh *FunctionHandler, st *SymbolTable) {
st.AddVar("x", "test", KindByte, 0)
fh.HandleFuncDecl(makeLine("FUNC test ( x )"))
},
line: "CALL test ( 1 2 )",
wantErr: "expected 1 arguments, got 2",
},
{
name: "type mismatch",
setup: func(fh *FunctionHandler, st *SymbolTable) {
st.AddVar("param", "test", KindByte, 0)
fh.HandleFuncDecl(makeLine("FUNC test ( param )"))
st.AddVar("wvar", "", KindWord, 0)
},
line: "CALL test ( wvar )",
wantErr: "type mismatch",
},
{
name: "const to out param",
setup: func(fh *FunctionHandler, st *SymbolTable) {
st.AddVar("result", "test", KindByte, 0)
fh.HandleFuncDecl(makeLine("FUNC test ( out:result )"))
},
line: "CALL test ( 42 )",
wantErr: "out/io parameter",
},
{
name: "label to byte param",
setup: func(fh *FunctionHandler, st *SymbolTable) {
st.AddVar("x", "test", KindByte, 0)
fh.HandleFuncDecl(makeLine("FUNC test ( x )"))
},
line: "CALL test ( @label )",
wantErr: "byte parameter",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
tt.setup(fh, st)
_, err := fh.HandleFuncCall(makeLine(tt.line))
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("error %q does not contain %q", err.Error(), tt.wantErr)
}
})
}
}
func TestEndFunction(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
// Declare function (pushes to stack)
fh.HandleFuncDecl(makeLine("FUNC test ( {BYTE x} )"))
if fh.CurrentFunction() != "test" {
t.Errorf("current function = %q, want 'test'", fh.CurrentFunction())
}
// End function
fh.EndFunction()
if fh.CurrentFunction() != "" {
t.Errorf("current function = %q, want ''", fh.CurrentFunction())
}
}
func TestCurrentFunction(t *testing.T) {
st := NewSymbolTable()
ls := NewLabelStack("L")
csh := NewConstantStringHandler()
pragma := preproc.NewPragma()
fh := NewFunctionHandler(st, ls, csh, pragma)
if fh.CurrentFunction() != "" {
t.Error("expected empty current function initially")
}
fh.HandleFuncDecl(makeLine("FUNC func1 ( {BYTE x} )"))
if fh.CurrentFunction() != "func1" {
t.Errorf("expected 'func1', got %q", fh.CurrentFunction())
}
fh.HandleFuncDecl(makeLine("FUNC func2 ( {BYTE y} )"))
if fh.CurrentFunction() != "func2" {
t.Errorf("expected 'func2', got %q", fh.CurrentFunction())
}
fh.EndFunction()
if fh.CurrentFunction() != "" {
t.Errorf("expected '', got %q", fh.CurrentFunction())
}
}