From bea0640f0baebe025721be05d12dc7d19eaa3c49 Mon Sep 17 00:00:00 2001 From: Mattias Hansson Date: Tue, 4 Nov 2025 17:45:41 +0100 Subject: [PATCH] Added while break wend with a generic comparisonGenerator object --- internal/commands/and_test.go | 4 + internal/commands/break.go | 51 +++ internal/commands/comparison.go | 638 ++++++++++++++++++++++++++++++++ internal/commands/if.go | 8 - internal/commands/wend.go | 59 +++ internal/commands/while.go | 152 ++++++++ internal/commands/while_test.go | 419 +++++++++++++++++++++ internal/compiler/context.go | 4 +- main.go | 4 + 9 files changed, 1330 insertions(+), 9 deletions(-) create mode 100644 internal/commands/break.go create mode 100644 internal/commands/comparison.go create mode 100644 internal/commands/wend.go create mode 100644 internal/commands/while.go create mode 100644 internal/commands/while_test.go diff --git a/internal/commands/and_test.go b/internal/commands/and_test.go index 63045fc..8ae7a68 100644 --- a/internal/commands/and_test.go +++ b/internal/commands/and_test.go @@ -412,6 +412,7 @@ func TestAndCommand_NewSyntax(t *testing.T) { } } +/* // equalAsm compares two assembly slices for equality func equalAsm(a, b []string) bool { if len(a) != len(b) { @@ -424,3 +425,6 @@ func equalAsm(a, b []string) bool { } return true } + + +*/ diff --git a/internal/commands/break.go b/internal/commands/break.go new file mode 100644 index 0000000..1d89b2e --- /dev/null +++ b/internal/commands/break.go @@ -0,0 +1,51 @@ +package commands + +import ( + "fmt" + "strings" + + "c65gm/internal/compiler" + "c65gm/internal/preproc" + "c65gm/internal/utils" +) + +// BreakCommand handles BREAK statements +// Syntax: BREAK +// Exits current WHILE loop +type BreakCommand struct { + skipLabel string +} + +func (c *BreakCommand) WillHandle(line preproc.Line) bool { + params, err := utils.ParseParams(line.Text) + if err != nil || len(params) == 0 { + return false + } + return strings.ToUpper(params[0]) == "BREAK" +} + +func (c *BreakCommand) Interpret(line preproc.Line, ctx *compiler.CompilerContext) error { + params, err := utils.ParseParams(line.Text) + if err != nil { + return err + } + + if len(params) != 1 { + return fmt.Errorf("BREAK: expected 1 parameter, got %d", len(params)) + } + + // Get skip label from WHILE stack + var err2 error + c.skipLabel, err2 = ctx.WhileStack.Peek() + if err2 != nil { + return fmt.Errorf("BREAK: not inside WHILE loop") + } + + return nil +} + +func (c *BreakCommand) Generate(_ *compiler.CompilerContext) ([]string, error) { + return []string{ + fmt.Sprintf("\tjmp %s", c.skipLabel), + }, nil +} diff --git a/internal/commands/comparison.go b/internal/commands/comparison.go new file mode 100644 index 0000000..c4833c1 --- /dev/null +++ b/internal/commands/comparison.go @@ -0,0 +1,638 @@ +package commands + +import ( + "fmt" + + "c65gm/internal/compiler" +) + +// comparisonOp represents comparison operators +type comparisonOp int + +const ( + opEqual comparisonOp = iota + opNotEqual + opGreater + opLess + opGreaterEqual + opLessEqual +) + +// operandInfo describes an operand (variable or literal) +type operandInfo struct { + varName string + varKind compiler.VarKind + value uint16 + isVar bool +} + +// comparisonGenerator generates comparison code for control flow +type comparisonGenerator struct { + operator comparisonOp + param1 *operandInfo + param2 *operandInfo + useLongJump bool + compFlowStack *compiler.LabelStack + generalStack *compiler.LabelStack + + skipLabel string + kind1 compiler.VarKind + kind2 compiler.VarKind +} + +// newComparisonGenerator creates a new comparison generator +func newComparisonGenerator( + operator comparisonOp, + param1 *operandInfo, + param2 *operandInfo, + useLongJump bool, + compFlowStack *compiler.LabelStack, + generalStack *compiler.LabelStack, +) (*comparisonGenerator, error) { + if param1 == nil { + return nil, fmt.Errorf("param1 cannot be nil") + } + + // Handle single operand: convert to != 0 + if param2 == nil { + operator = opNotEqual + param2 = &operandInfo{value: 0, isVar: false} + } + + return &comparisonGenerator{ + operator: operator, + param1: param1, + param2: param2, + useLongJump: useLongJump, + compFlowStack: compFlowStack, + generalStack: generalStack, + }, nil +} + +// generate produces assembly that jumps to skip label on FALSE +func (cg *comparisonGenerator) generate() ([]string, error) { + var err error + cg.skipLabel, err = cg.compFlowStack.Peek() + if err != nil { + return nil, fmt.Errorf("comparison: %w", err) + } + + cg.kind1 = cg.getKind(cg.param1) + cg.kind2 = cg.getKind(cg.param2) + + // Constant folding + if !cg.param1.isVar && !cg.param2.isVar { + return cg.constantFold() + } + + // Dispatch to operator + switch cg.operator { + case opEqual: + return cg.genEqual() + case opNotEqual: + return cg.genNotEqual() + case opGreater: + return cg.genGreater() + case opLess: + return cg.genLess() + case opGreaterEqual: + return cg.genGreaterEqual() + case opLessEqual: + return cg.genLessEqual() + default: + return nil, fmt.Errorf("unsupported operator %d", cg.operator) + } +} + +// getKind returns the effective kind for an operand +func (cg *comparisonGenerator) getKind(op *operandInfo) compiler.VarKind { + if op.isVar { + return op.varKind + } + return inferKindFromValue(op.value) +} + +// constantFold evaluates comparisons between two literals +func (cg *comparisonGenerator) constantFold() ([]string, error) { + v1, v2 := cg.param1.value, cg.param2.value + + var result bool + switch cg.operator { + case opEqual: + result = v1 == v2 + case opNotEqual: + result = v1 != v2 + case opGreater: + result = v1 > v2 + case opLess: + result = v1 < v2 + case opGreaterEqual: + result = v1 >= v2 + case opLessEqual: + result = v1 <= v2 + default: + return nil, fmt.Errorf("operator not supported for constant folding") + } + + // If false, skip block + if !result { + return []string{fmt.Sprintf("\tjmp %s", cg.skipLabel)}, nil + } + return []string{}, nil +} + +// loadOperand generates LDA instruction for operand +func (cg *comparisonGenerator) loadOperand(op *operandInfo, offset int) string { + if op.isVar { + if offset > 0 { + return fmt.Sprintf("\tlda %s+%d", op.varName, offset) + } + return fmt.Sprintf("\tlda %s", op.varName) + } + + val := op.value + if offset == 1 { + val >>= 8 + } + return fmt.Sprintf("\tlda #$%02x", uint8(val)) +} + +// cmpOperand generates CMP instruction for operand +func (cg *comparisonGenerator) cmpOperand(op *operandInfo, offset int) string { + if op.isVar { + if offset > 0 { + return fmt.Sprintf("\tcmp %s+%d", op.varName, offset) + } + return fmt.Sprintf("\tcmp %s", op.varName) + } + + val := op.value + if offset == 1 { + val >>= 8 + } + return fmt.Sprintf("\tcmp #$%02x", uint8(val)) +} + +// tempLabel creates and returns a temporary label +func (cg *comparisonGenerator) tempLabel() string { + label := cg.generalStack.Push() + cg.generalStack.Pop() + return label +} + +// extractByteWord separates byte and word operands +func (cg *comparisonGenerator) extractByteWord() (byte, word *operandInfo) { + if cg.kind1 == compiler.KindByte { + return cg.param1, cg.param2 + } + return cg.param2, cg.param1 +} + +// == operator +func (cg *comparisonGenerator) genEqual() ([]string, error) { + // byte == byte + if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte { + return cg.genByteEqual() + } + + // word == word + if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord { + return cg.genWordEqual() + } + + // mixed + return cg.genMixedEqual() +} + +func (cg *comparisonGenerator) genByteEqual() ([]string, error) { + if !cg.useLongJump { + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbne %s", cg.skipLabel), + }, nil + } + + success := cg.tempLabel() + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbeq %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +func (cg *comparisonGenerator) genWordEqual() ([]string, error) { + if !cg.useLongJump { + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbne %s", cg.skipLabel), + cg.loadOperand(cg.param1, 1), + cg.cmpOperand(cg.param2, 1), + fmt.Sprintf("\tbne %s", cg.skipLabel), + }, nil + } + + success := cg.tempLabel() + fail := cg.tempLabel() + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbne %s", fail), + cg.loadOperand(cg.param1, 1), + cg.cmpOperand(cg.param2, 1), + fmt.Sprintf("\tbeq %s", success), + fail, + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +func (cg *comparisonGenerator) genMixedEqual() ([]string, error) { + byteOp, wordOp := cg.extractByteWord() + + if !cg.useLongJump { + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", cg.skipLabel), + cg.loadOperand(byteOp, 0), + cg.cmpOperand(wordOp, 0), + fmt.Sprintf("\tbne %s", cg.skipLabel), + }, nil + } + + success := cg.tempLabel() + fail := cg.tempLabel() + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", fail), + cg.loadOperand(byteOp, 0), + cg.cmpOperand(wordOp, 0), + fmt.Sprintf("\tbeq %s", success), + fail, + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +// != operator +func (cg *comparisonGenerator) genNotEqual() ([]string, error) { + if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte { + return cg.genByteNotEqual() + } + + if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord { + return cg.genWordNotEqual() + } + + return cg.genMixedNotEqual() +} + +func (cg *comparisonGenerator) genByteNotEqual() ([]string, error) { + if !cg.useLongJump { + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbeq %s", cg.skipLabel), + }, nil + } + + success := cg.tempLabel() + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbne %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +func (cg *comparisonGenerator) genWordNotEqual() ([]string, error) { + success := cg.tempLabel() + + if !cg.useLongJump { + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbne %s", success), + cg.loadOperand(cg.param1, 1), + cg.cmpOperand(cg.param2, 1), + fmt.Sprintf("\tbne %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil + } + + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbne %s", success), + cg.loadOperand(cg.param1, 1), + cg.cmpOperand(cg.param2, 1), + fmt.Sprintf("\tbne %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +func (cg *comparisonGenerator) genMixedNotEqual() ([]string, error) { + byteOp, wordOp := cg.extractByteWord() + success := cg.tempLabel() + + if !cg.useLongJump { + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", success), + cg.loadOperand(byteOp, 0), + cg.cmpOperand(wordOp, 0), + fmt.Sprintf("\tbeq %s", cg.skipLabel), + success, + }, nil + } + + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", success), + cg.loadOperand(byteOp, 0), + cg.cmpOperand(wordOp, 0), + fmt.Sprintf("\tbne %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +// > operator (unsigned) +func (cg *comparisonGenerator) genGreater() ([]string, error) { + if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte { + return cg.genByteGreater() + } + + if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord { + return cg.genWordGreater() + } + + return cg.genMixedGreater() +} + +func (cg *comparisonGenerator) genByteGreater() ([]string, error) { + // p1 > p2: skip if p1 <= p2 + // CMP sets C=1 if A >= operand + if !cg.useLongJump { + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbeq %s", cg.skipLabel), // equal means not greater + fmt.Sprintf("\tbcc %s", cg.skipLabel), // carry clear means less + }, nil + } + + success := cg.tempLabel() + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbeq %s", cg.skipLabel), + fmt.Sprintf("\tbcs %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +func (cg *comparisonGenerator) genWordGreater() ([]string, error) { + // Compare high bytes first + success := cg.tempLabel() + + if !cg.useLongJump { + return []string{ + cg.loadOperand(cg.param1, 1), + cg.cmpOperand(cg.param2, 1), + fmt.Sprintf("\tbcc %s", cg.skipLabel), // p1_hi < p2_hi + fmt.Sprintf("\tbne %s", success), // p1_hi > p2_hi + // High bytes equal, check low + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbeq %s", cg.skipLabel), // equal not greater + fmt.Sprintf("\tbcc %s", cg.skipLabel), // p1_lo < p2_lo + success, + }, nil + } + + return []string{ + cg.loadOperand(cg.param1, 1), + cg.cmpOperand(cg.param2, 1), + fmt.Sprintf("\tbcc %s", cg.skipLabel), + fmt.Sprintf("\tbne %s", success), + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbeq %s", cg.skipLabel), + fmt.Sprintf("\tbcs %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +func (cg *comparisonGenerator) genMixedGreater() ([]string, error) { + byteOp, wordOp := cg.extractByteWord() + + // If byte is param1 and word is param2: byte > word only if word_hi=0 and byte > word_lo + // If word is param1 and byte is param2: word > byte if word_hi>0 OR word_lo > byte + if cg.kind1 == compiler.KindByte { + // byte > word + success := cg.tempLabel() + if !cg.useLongJump { + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", cg.skipLabel), // word too large + cg.loadOperand(byteOp, 0), + cg.cmpOperand(wordOp, 0), + fmt.Sprintf("\tbeq %s", cg.skipLabel), + fmt.Sprintf("\tbcc %s", cg.skipLabel), + }, nil + } + + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", cg.skipLabel), + cg.loadOperand(byteOp, 0), + cg.cmpOperand(wordOp, 0), + fmt.Sprintf("\tbeq %s", cg.skipLabel), + fmt.Sprintf("\tbcs %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil + } + + // word > byte + success := cg.tempLabel() + if !cg.useLongJump { + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", success), // word_hi > 0 means always greater + cg.loadOperand(wordOp, 0), + cg.cmpOperand(byteOp, 0), + fmt.Sprintf("\tbeq %s", cg.skipLabel), + fmt.Sprintf("\tbcc %s", cg.skipLabel), + success, + }, nil + } + + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", success), + cg.loadOperand(wordOp, 0), + cg.cmpOperand(byteOp, 0), + fmt.Sprintf("\tbeq %s", cg.skipLabel), + fmt.Sprintf("\tbcs %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +// < operator +func (cg *comparisonGenerator) genLess() ([]string, error) { + // p1 < p2 is equivalent to p2 > p1 + cg.param1, cg.param2 = cg.param2, cg.param1 + cg.kind1, cg.kind2 = cg.kind2, cg.kind1 + return cg.genGreater() +} + +// >= operator +func (cg *comparisonGenerator) genGreaterEqual() ([]string, error) { + if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte { + return cg.genByteGreaterEqual() + } + + if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord { + return cg.genWordGreaterEqual() + } + + return cg.genMixedGreaterEqual() +} + +func (cg *comparisonGenerator) genByteGreaterEqual() ([]string, error) { + // p1 >= p2: skip if p1 < p2 + if !cg.useLongJump { + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbcc %s", cg.skipLabel), + }, nil + } + + success := cg.tempLabel() + return []string{ + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbcs %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +func (cg *comparisonGenerator) genWordGreaterEqual() ([]string, error) { + success := cg.tempLabel() + + if !cg.useLongJump { + return []string{ + cg.loadOperand(cg.param1, 1), + cg.cmpOperand(cg.param2, 1), + fmt.Sprintf("\tbcc %s", cg.skipLabel), + fmt.Sprintf("\tbne %s", success), + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbcc %s", cg.skipLabel), + success, + }, nil + } + + return []string{ + cg.loadOperand(cg.param1, 1), + cg.cmpOperand(cg.param2, 1), + fmt.Sprintf("\tbcc %s", cg.skipLabel), + fmt.Sprintf("\tbne %s", success), + cg.loadOperand(cg.param1, 0), + cg.cmpOperand(cg.param2, 0), + fmt.Sprintf("\tbcs %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +func (cg *comparisonGenerator) genMixedGreaterEqual() ([]string, error) { + byteOp, wordOp := cg.extractByteWord() + + if cg.kind1 == compiler.KindByte { + // byte >= word + success := cg.tempLabel() + if !cg.useLongJump { + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", cg.skipLabel), + cg.loadOperand(byteOp, 0), + cg.cmpOperand(wordOp, 0), + fmt.Sprintf("\tbcc %s", cg.skipLabel), + }, nil + } + + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", cg.skipLabel), + cg.loadOperand(byteOp, 0), + cg.cmpOperand(wordOp, 0), + fmt.Sprintf("\tbcs %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil + } + + // word >= byte + success := cg.tempLabel() + if !cg.useLongJump { + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", success), + cg.loadOperand(wordOp, 0), + cg.cmpOperand(byteOp, 0), + fmt.Sprintf("\tbcc %s", cg.skipLabel), + success, + }, nil + } + + return []string{ + cg.loadOperand(wordOp, 1), + "\tcmp #0", + fmt.Sprintf("\tbne %s", success), + cg.loadOperand(wordOp, 0), + cg.cmpOperand(byteOp, 0), + fmt.Sprintf("\tbcs %s", success), + fmt.Sprintf("\tjmp %s", cg.skipLabel), + success, + }, nil +} + +// <= operator +func (cg *comparisonGenerator) genLessEqual() ([]string, error) { + // p1 <= p2 is equivalent to p2 >= p1 + cg.param1, cg.param2 = cg.param2, cg.param1 + cg.kind1, cg.kind2 = cg.kind2, cg.kind1 + return cg.genGreaterEqual() +} + +func inferKindFromValue(val uint16) compiler.VarKind { + if val <= 255 { + return compiler.KindByte + } + return compiler.KindWord +} diff --git a/internal/commands/if.go b/internal/commands/if.go index f731837..e726647 100644 --- a/internal/commands/if.go +++ b/internal/commands/if.go @@ -627,11 +627,3 @@ func (c *IfCommand) generateNotEqualLongJump(ctx *compiler.CompilerContext) ([]s asm = append(asm, successLabel) return asm, nil } - -// inferKindFromValue determines if a literal value is byte or word -func inferKindFromValue(val uint16) compiler.VarKind { - if val <= 255 { - return compiler.KindByte - } - return compiler.KindWord -} diff --git a/internal/commands/wend.go b/internal/commands/wend.go new file mode 100644 index 0000000..6a5a3d2 --- /dev/null +++ b/internal/commands/wend.go @@ -0,0 +1,59 @@ +package commands + +import ( + "fmt" + "strings" + + "c65gm/internal/compiler" + "c65gm/internal/preproc" + "c65gm/internal/utils" +) + +// WendCommand handles WEND statements +// Syntax: WEND +// Ends current WHILE loop +type WendCommand struct { + loopLabel string + skipLabel string +} + +func (c *WendCommand) WillHandle(line preproc.Line) bool { + params, err := utils.ParseParams(line.Text) + if err != nil || len(params) == 0 { + return false + } + return strings.ToUpper(params[0]) == "WEND" +} + +func (c *WendCommand) Interpret(line preproc.Line, ctx *compiler.CompilerContext) error { + params, err := utils.ParseParams(line.Text) + if err != nil { + return err + } + + if len(params) != 1 { + return fmt.Errorf("WEND: expected 1 parameter, got %d", len(params)) + } + + // Pop loop label + var err2 error + c.loopLabel, err2 = ctx.LoopStack.Pop() + if err2 != nil { + return fmt.Errorf("WEND: not inside WHILE loop") + } + + // Pop skip label + c.skipLabel, err2 = ctx.WhileStack.Pop() + if err2 != nil { + return fmt.Errorf("WEND: not inside WHILE loop") + } + + return nil +} + +func (c *WendCommand) Generate(_ *compiler.CompilerContext) ([]string, error) { + return []string{ + fmt.Sprintf("\tjmp %s", c.loopLabel), + c.skipLabel, + }, nil +} diff --git a/internal/commands/while.go b/internal/commands/while.go new file mode 100644 index 0000000..7e44442 --- /dev/null +++ b/internal/commands/while.go @@ -0,0 +1,152 @@ +package commands + +import ( + "fmt" + "strings" + + "c65gm/internal/compiler" + "c65gm/internal/preproc" + "c65gm/internal/utils" +) + +// WhileCommand handles WHILE loop statements +// Syntax: WHILE +// Operators: =, ==, <>, !=, >, <, >=, <= +type WhileCommand struct { + operator string + param1 *operandInfo + param2 *operandInfo + useLongJump bool + loopLabel string + skipLabel string +} + +func (c *WhileCommand) WillHandle(line preproc.Line) bool { + params, err := utils.ParseParams(line.Text) + if err != nil || len(params) == 0 { + return false + } + return strings.ToUpper(params[0]) == "WHILE" +} + +func (c *WhileCommand) Interpret(line preproc.Line, ctx *compiler.CompilerContext) error { + params, err := utils.ParseParams(line.Text) + if err != nil { + return err + } + + if len(params) != 4 { + return fmt.Errorf("WHILE: expected 4 parameters, got %d", len(params)) + } + + c.operator = normalizeOperator(params[2]) + scope := ctx.CurrentScope() + + constLookup := func(name string) (int64, bool) { + sym := ctx.SymbolTable.Lookup(name, scope) + if sym != nil && sym.IsConst() { + return int64(sym.Value), true + } + return 0, false + } + + // Parse param1 + varName, varKind, value, isVar, err := compiler.ParseOperandParam( + params[1], ctx.SymbolTable, scope, constLookup) + if err != nil { + return fmt.Errorf("WHILE: param1: %w", err) + } + c.param1 = &operandInfo{ + varName: varName, + varKind: varKind, + value: value, + isVar: isVar, + } + + // Parse param2 + varName, varKind, value, isVar, err = compiler.ParseOperandParam( + params[3], ctx.SymbolTable, scope, constLookup) + if err != nil { + return fmt.Errorf("WHILE: param2: %w", err) + } + c.param2 = &operandInfo{ + varName: varName, + varKind: varKind, + value: value, + isVar: isVar, + } + + // Check pragma + ps := ctx.Pragma.GetPragmaSetByIndex(line.PragmaSetIndex) + longJumpPragma := ps.GetPragma("_P_USE_LONG_JUMP") + c.useLongJump = longJumpPragma != "" && longJumpPragma != "0" + + // Create labels + c.loopLabel = ctx.LoopStack.Push() + c.skipLabel = ctx.WhileStack.Push() + + return nil +} + +func (c *WhileCommand) Generate(ctx *compiler.CompilerContext) ([]string, error) { + op, err := parseOperator(c.operator) + if err != nil { + return nil, fmt.Errorf("WHILE: %w", err) + } + + // Emit loop label + asm := []string{c.loopLabel} + + // Generate comparison (jumps to skipLabel on false) + gen, err := newComparisonGenerator( + op, + c.param1, + c.param2, + c.useLongJump, + ctx.WhileStack, + ctx.GeneralStack, + ) + if err != nil { + return nil, fmt.Errorf("WHILE: %w", err) + } + + cmpAsm, err := gen.generate() + if err != nil { + return nil, fmt.Errorf("WHILE: %w", err) + } + + asm = append(asm, cmpAsm...) + return asm, nil +} + +// normalizeOperator converts operator variants to canonical form +func normalizeOperator(op string) string { + switch op { + case "==": + return "=" + case "!=": + return "<>" + default: + return op + } +} + +// parseOperator converts string operator to comparisonOp +func parseOperator(op string) (comparisonOp, error) { + switch op { + case "=": + return opEqual, nil + case "<>": + return opNotEqual, nil + case ">": + return opGreater, nil + case "<": + return opLess, nil + case ">=": + return opGreaterEqual, nil + case "<=": + return opLessEqual, nil + default: + return 0, fmt.Errorf("unsupported operator %q", op) + } +} diff --git a/internal/commands/while_test.go b/internal/commands/while_test.go new file mode 100644 index 0000000..39b1296 --- /dev/null +++ b/internal/commands/while_test.go @@ -0,0 +1,419 @@ +package commands + +import ( + "strings" + "testing" + + "c65gm/internal/compiler" + "c65gm/internal/preproc" +) + +func TestWhileBasicEqual(t *testing.T) { + tests := []struct { + name string + whileLine string + setupVars func(*compiler.SymbolTable) + wantWhile []string + wantWend []string + }{ + { + name: "byte var == byte literal", + whileLine: "WHILE x = 10", + setupVars: func(st *compiler.SymbolTable) { + st.AddVar("x", "", compiler.KindByte, 0) + }, + wantWhile: []string{ + "_LOOP1", + "\tlda x", + "\tcmp #$0a", + "\tbne _WEND1", + }, + wantWend: []string{ + "\tjmp _LOOP1", + "_WEND1", + }, + }, + { + name: "word var == word literal", + whileLine: "WHILE x = 1000", + setupVars: func(st *compiler.SymbolTable) { + st.AddVar("x", "", compiler.KindWord, 0) + }, + wantWhile: []string{ + "_LOOP1", + "\tlda x", + "\tcmp #$e8", + "\tbne _WEND1", + "\tlda x+1", + "\tcmp #$03", + "\tbne _WEND1", + }, + wantWend: []string{ + "\tjmp _LOOP1", + "_WEND1", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + tt.setupVars(ctx.SymbolTable) + + whileCmd := &WhileCommand{} + wendCmd := &WendCommand{} + + whileLine := preproc.Line{ + Text: tt.whileLine, + Kind: preproc.Source, + PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(), + } + wendLine := preproc.Line{ + Text: "WEND", + Kind: preproc.Source, + PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(), + } + + if err := whileCmd.Interpret(whileLine, ctx); err != nil { + t.Fatalf("WHILE Interpret() error = %v", err) + } + + whileAsm, err := whileCmd.Generate(ctx) + if err != nil { + t.Fatalf("WHILE Generate() error = %v", err) + } + + if err := wendCmd.Interpret(wendLine, ctx); err != nil { + t.Fatalf("WEND Interpret() error = %v", err) + } + + wendAsm, err := wendCmd.Generate(ctx) + if err != nil { + t.Fatalf("WEND Generate() error = %v", err) + } + + if !equalAsm(whileAsm, tt.wantWhile) { + t.Errorf("WHILE Generate() mismatch\ngot:\n%s\nwant:\n%s", + strings.Join(whileAsm, "\n"), + strings.Join(tt.wantWhile, "\n")) + } + if !equalAsm(wendAsm, tt.wantWend) { + t.Errorf("WEND Generate() mismatch\ngot:\n%s\nwant:\n%s", + strings.Join(wendAsm, "\n"), + strings.Join(tt.wantWend, "\n")) + } + }) + } +} + +func TestWhileAllOperators(t *testing.T) { + tests := []struct { + name string + line string + wantInst string + }{ + {"equal", "WHILE x = 10", "bne"}, + {"not equal", "WHILE x <> 10", "beq"}, + {"greater", "WHILE x > 10", "beq"}, + {"less", "WHILE x < 10", "beq"}, + {"greater equal", "WHILE x >= 10", "bcc"}, + {"less equal", "WHILE x <= 10", "bcc"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + ctx.SymbolTable.AddVar("x", "", compiler.KindByte, 0) + + cmd := &WhileCommand{} + line := preproc.Line{ + Text: tt.line, + Kind: preproc.Source, + PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(), + } + + if err := cmd.Interpret(line, ctx); err != nil { + t.Fatalf("Interpret() error = %v", err) + } + + asm, err := cmd.Generate(ctx) + if err != nil { + t.Fatalf("Generate() error = %v", err) + } + + found := false + for _, inst := range asm { + if strings.Contains(inst, tt.wantInst) { + found = true + break + } + } + if !found { + t.Errorf("Expected %s instruction not found in: %v", tt.wantInst, asm) + } + }) + } +} + +func TestWhileMixedTypes(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + ctx.SymbolTable.AddVar("x", "", compiler.KindByte, 0) + ctx.SymbolTable.AddVar("y", "", compiler.KindWord, 0) + + cmd := &WhileCommand{} + line := preproc.Line{ + Text: "WHILE x < y", + Kind: preproc.Source, + PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(), + } + + if err := cmd.Interpret(line, ctx); err != nil { + t.Fatalf("Interpret() error = %v", err) + } + + asm, err := cmd.Generate(ctx) + if err != nil { + t.Fatalf("Generate() error = %v", err) + } + + foundHighByteCheck := false + for _, inst := range asm { + if strings.Contains(inst, "y+1") { + foundHighByteCheck = true + break + } + } + if !foundHighByteCheck { + t.Error("Expected high byte check for word in mixed comparison") + } +} + +func TestWhileBreak(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + ctx.SymbolTable.AddVar("x", "", compiler.KindByte, 0) + + whileCmd := &WhileCommand{} + breakCmd := &BreakCommand{} + wendCmd := &WendCommand{} + + pragmaIdx := pragma.GetCurrentPragmaSetIndex() + + whileLine := preproc.Line{Text: "WHILE x < 10", Kind: preproc.Source, PragmaSetIndex: pragmaIdx} + breakLine := preproc.Line{Text: "BREAK", Kind: preproc.Source, PragmaSetIndex: pragmaIdx} + wendLine := preproc.Line{Text: "WEND", Kind: preproc.Source, PragmaSetIndex: pragmaIdx} + + if err := whileCmd.Interpret(whileLine, ctx); err != nil { + t.Fatalf("WHILE Interpret() error = %v", err) + } + + whileAsm, _ := whileCmd.Generate(ctx) + _ = whileAsm // body would go here + + if err := breakCmd.Interpret(breakLine, ctx); err != nil { + t.Fatalf("BREAK Interpret() error = %v", err) + } + + breakAsm, err := breakCmd.Generate(ctx) + if err != nil { + t.Fatalf("BREAK Generate() error = %v", err) + } + + if err := wendCmd.Interpret(wendLine, ctx); err != nil { + t.Fatalf("WEND Interpret() error = %v", err) + } + + if len(breakAsm) != 1 || !strings.Contains(breakAsm[0], "jmp _WEND") { + t.Errorf("BREAK should jump to WEND label, got: %v", breakAsm) + } +} + +func TestBreakOutsideLoop(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + + cmd := &BreakCommand{} + line := preproc.Line{ + Text: "BREAK", + Kind: preproc.Source, + PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(), + } + + err := cmd.Interpret(line, ctx) + if err == nil { + t.Fatal("BREAK outside loop should fail") + } + if !strings.Contains(err.Error(), "not inside WHILE") { + t.Errorf("Wrong error message: %v", err) + } +} + +func TestWendWithoutWhile(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + + cmd := &WendCommand{} + line := preproc.Line{ + Text: "WEND", + Kind: preproc.Source, + PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(), + } + + err := cmd.Interpret(line, ctx) + if err == nil { + t.Fatal("WEND without WHILE should fail") + } + if !strings.Contains(err.Error(), "not inside WHILE") { + t.Errorf("Wrong error message: %v", err) + } +} + +func TestWhileNested(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + ctx.SymbolTable.AddVar("i", "", compiler.KindByte, 0) + ctx.SymbolTable.AddVar("j", "", compiler.KindByte, 0) + + pragmaIdx := pragma.GetCurrentPragmaSetIndex() + + while1 := &WhileCommand{} + while2 := &WhileCommand{} + wend1 := &WendCommand{} + wend2 := &WendCommand{} + + if err := while1.Interpret(preproc.Line{Text: "WHILE i < 10", Kind: preproc.Source, PragmaSetIndex: pragmaIdx}, ctx); err != nil { + t.Fatalf("WHILE 1 error = %v", err) + } + asm1, err := while1.Generate(ctx) + if err != nil { + t.Fatalf("WHILE 1 Generate error = %v", err) + } + + if err := while2.Interpret(preproc.Line{Text: "WHILE j < 5", Kind: preproc.Source, PragmaSetIndex: pragmaIdx}, ctx); err != nil { + t.Fatalf("WHILE 2 error = %v", err) + } + asm2, err := while2.Generate(ctx) + if err != nil { + t.Fatalf("WHILE 2 Generate error = %v", err) + } + + if err := wend2.Interpret(preproc.Line{Text: "WEND", Kind: preproc.Source, PragmaSetIndex: pragmaIdx}, ctx); err != nil { + t.Fatalf("WEND 2 error = %v", err) + } + if err := wend1.Interpret(preproc.Line{Text: "WEND", Kind: preproc.Source, PragmaSetIndex: pragmaIdx}, ctx); err != nil { + t.Fatalf("WEND 1 error = %v", err) + } + + if asm1[0] == asm2[0] { + t.Error("Nested loops should have different labels") + } +} + +func TestWhileLongJump(t *testing.T) { + pragma := preproc.NewPragma() + pragma.AddPragma("_P_USE_LONG_JUMP", "1") + ctx := compiler.NewCompilerContext(pragma) + ctx.SymbolTable.AddVar("x", "", compiler.KindByte, 0) + + cmd := &WhileCommand{} + line := preproc.Line{ + Text: "WHILE x = 10", + Kind: preproc.Source, + PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(), + } + + if err := cmd.Interpret(line, ctx); err != nil { + t.Fatalf("Interpret() error = %v", err) + } + + asm, err := cmd.Generate(ctx) + if err != nil { + t.Fatalf("Generate() error = %v", err) + } + + foundJmp := false + for _, inst := range asm { + if strings.Contains(inst, "jmp") { + foundJmp = true + break + } + } + if !foundJmp { + t.Error("Long jump mode should contain JMP instruction") + } +} + +func TestWhileConstant(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + ctx.SymbolTable.AddConst("MAX", "", compiler.KindByte, 100) + ctx.SymbolTable.AddVar("x", "", compiler.KindByte, 0) + + cmd := &WhileCommand{} + line := preproc.Line{ + Text: "WHILE x < MAX", + Kind: preproc.Source, + PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(), + } + + if err := cmd.Interpret(line, ctx); err != nil { + t.Fatalf("Interpret() error = %v", err) + } + + asm, err := cmd.Generate(ctx) + if err != nil { + t.Fatalf("Generate() error = %v", err) + } + + found := false + for _, inst := range asm { + if strings.Contains(inst, "#$64") { + found = true + break + } + } + if !found { + t.Error("Constant should be folded to immediate value") + } +} + +func TestWhileWrongParamCount(t *testing.T) { + pragma := preproc.NewPragma() + ctx := compiler.NewCompilerContext(pragma) + + tests := []string{ + "WHILE x", + "WHILE x =", + "WHILE x = 10 extra", + } + + for _, text := range tests { + cmd := &WhileCommand{} + line := preproc.Line{ + Text: text, + Kind: preproc.Source, + PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(), + } + + err := cmd.Interpret(line, ctx) + if err == nil { + t.Errorf("Should fail with wrong param count: %s", text) + } + } +} + +// Helper to compare assembly output +func equalAsm(got, want []string) bool { + if len(got) != len(want) { + return false + } + for i := range got { + if got[i] != want[i] { + return false + } + } + return true +} diff --git a/internal/compiler/context.go b/internal/compiler/context.go index 5e7a458..dc6c312 100644 --- a/internal/compiler/context.go +++ b/internal/compiler/context.go @@ -16,6 +16,7 @@ type CompilerContext struct { ConstStrHandler *ConstantStringHandler // Label stacks for control flow + LoopStack *LabelStack // Start of loop (like WHILE) WhileStack *LabelStack // WHILE...WEND IfStack *LabelStack // IF...ENDIF GeneralStack *LabelStack // General purpose (GOSUB, etc) @@ -33,7 +34,8 @@ func NewCompilerContext(pragma *preproc.Pragma) *CompilerContext { ctx := &CompilerContext{ SymbolTable: symTable, ConstStrHandler: constStrHandler, - WhileStack: NewLabelStack("_W"), + LoopStack: NewLabelStack("_LOOP"), + WhileStack: NewLabelStack("_WEND"), IfStack: NewLabelStack("_I"), GeneralStack: generalStack, Pragma: pragma, diff --git a/main.go b/main.go index 2cf7702..7e3d9a9 100644 --- a/main.go +++ b/main.go @@ -87,6 +87,10 @@ func registerCommands(comp *compiler.Compiler) { comp.Registry().Register(&commands.IfCommand{}) comp.Registry().Register(&commands.ElseCommand{}) comp.Registry().Register(&commands.EndIfCommand{}) + comp.Registry().Register(&commands.WhileCommand{}) + comp.Registry().Register(&commands.BreakCommand{}) + comp.Registry().Register(&commands.WendCommand{}) + } func writeOutput(filename string, lines []string) error {