Added while break wend with a generic comparisonGenerator object

This commit is contained in:
Mattias Hansson 2025-11-04 17:45:41 +01:00
parent a0e8bf40ea
commit bea0640f0b
9 changed files with 1330 additions and 9 deletions

View file

@ -412,6 +412,7 @@ func TestAndCommand_NewSyntax(t *testing.T) {
} }
} }
/*
// equalAsm compares two assembly slices for equality // equalAsm compares two assembly slices for equality
func equalAsm(a, b []string) bool { func equalAsm(a, b []string) bool {
if len(a) != len(b) { if len(a) != len(b) {
@ -424,3 +425,6 @@ func equalAsm(a, b []string) bool {
} }
return true return true
} }
*/

View file

@ -0,0 +1,51 @@
package commands
import (
"fmt"
"strings"
"c65gm/internal/compiler"
"c65gm/internal/preproc"
"c65gm/internal/utils"
)
// BreakCommand handles BREAK statements
// Syntax: BREAK
// Exits current WHILE loop
type BreakCommand struct {
skipLabel string
}
func (c *BreakCommand) WillHandle(line preproc.Line) bool {
params, err := utils.ParseParams(line.Text)
if err != nil || len(params) == 0 {
return false
}
return strings.ToUpper(params[0]) == "BREAK"
}
func (c *BreakCommand) Interpret(line preproc.Line, ctx *compiler.CompilerContext) error {
params, err := utils.ParseParams(line.Text)
if err != nil {
return err
}
if len(params) != 1 {
return fmt.Errorf("BREAK: expected 1 parameter, got %d", len(params))
}
// Get skip label from WHILE stack
var err2 error
c.skipLabel, err2 = ctx.WhileStack.Peek()
if err2 != nil {
return fmt.Errorf("BREAK: not inside WHILE loop")
}
return nil
}
func (c *BreakCommand) Generate(_ *compiler.CompilerContext) ([]string, error) {
return []string{
fmt.Sprintf("\tjmp %s", c.skipLabel),
}, nil
}

View file

@ -0,0 +1,638 @@
package commands
import (
"fmt"
"c65gm/internal/compiler"
)
// comparisonOp represents comparison operators
type comparisonOp int
const (
opEqual comparisonOp = iota
opNotEqual
opGreater
opLess
opGreaterEqual
opLessEqual
)
// operandInfo describes an operand (variable or literal)
type operandInfo struct {
varName string
varKind compiler.VarKind
value uint16
isVar bool
}
// comparisonGenerator generates comparison code for control flow
type comparisonGenerator struct {
operator comparisonOp
param1 *operandInfo
param2 *operandInfo
useLongJump bool
compFlowStack *compiler.LabelStack
generalStack *compiler.LabelStack
skipLabel string
kind1 compiler.VarKind
kind2 compiler.VarKind
}
// newComparisonGenerator creates a new comparison generator
func newComparisonGenerator(
operator comparisonOp,
param1 *operandInfo,
param2 *operandInfo,
useLongJump bool,
compFlowStack *compiler.LabelStack,
generalStack *compiler.LabelStack,
) (*comparisonGenerator, error) {
if param1 == nil {
return nil, fmt.Errorf("param1 cannot be nil")
}
// Handle single operand: convert to != 0
if param2 == nil {
operator = opNotEqual
param2 = &operandInfo{value: 0, isVar: false}
}
return &comparisonGenerator{
operator: operator,
param1: param1,
param2: param2,
useLongJump: useLongJump,
compFlowStack: compFlowStack,
generalStack: generalStack,
}, nil
}
// generate produces assembly that jumps to skip label on FALSE
func (cg *comparisonGenerator) generate() ([]string, error) {
var err error
cg.skipLabel, err = cg.compFlowStack.Peek()
if err != nil {
return nil, fmt.Errorf("comparison: %w", err)
}
cg.kind1 = cg.getKind(cg.param1)
cg.kind2 = cg.getKind(cg.param2)
// Constant folding
if !cg.param1.isVar && !cg.param2.isVar {
return cg.constantFold()
}
// Dispatch to operator
switch cg.operator {
case opEqual:
return cg.genEqual()
case opNotEqual:
return cg.genNotEqual()
case opGreater:
return cg.genGreater()
case opLess:
return cg.genLess()
case opGreaterEqual:
return cg.genGreaterEqual()
case opLessEqual:
return cg.genLessEqual()
default:
return nil, fmt.Errorf("unsupported operator %d", cg.operator)
}
}
// getKind returns the effective kind for an operand
func (cg *comparisonGenerator) getKind(op *operandInfo) compiler.VarKind {
if op.isVar {
return op.varKind
}
return inferKindFromValue(op.value)
}
// constantFold evaluates comparisons between two literals
func (cg *comparisonGenerator) constantFold() ([]string, error) {
v1, v2 := cg.param1.value, cg.param2.value
var result bool
switch cg.operator {
case opEqual:
result = v1 == v2
case opNotEqual:
result = v1 != v2
case opGreater:
result = v1 > v2
case opLess:
result = v1 < v2
case opGreaterEqual:
result = v1 >= v2
case opLessEqual:
result = v1 <= v2
default:
return nil, fmt.Errorf("operator not supported for constant folding")
}
// If false, skip block
if !result {
return []string{fmt.Sprintf("\tjmp %s", cg.skipLabel)}, nil
}
return []string{}, nil
}
// loadOperand generates LDA instruction for operand
func (cg *comparisonGenerator) loadOperand(op *operandInfo, offset int) string {
if op.isVar {
if offset > 0 {
return fmt.Sprintf("\tlda %s+%d", op.varName, offset)
}
return fmt.Sprintf("\tlda %s", op.varName)
}
val := op.value
if offset == 1 {
val >>= 8
}
return fmt.Sprintf("\tlda #$%02x", uint8(val))
}
// cmpOperand generates CMP instruction for operand
func (cg *comparisonGenerator) cmpOperand(op *operandInfo, offset int) string {
if op.isVar {
if offset > 0 {
return fmt.Sprintf("\tcmp %s+%d", op.varName, offset)
}
return fmt.Sprintf("\tcmp %s", op.varName)
}
val := op.value
if offset == 1 {
val >>= 8
}
return fmt.Sprintf("\tcmp #$%02x", uint8(val))
}
// tempLabel creates and returns a temporary label
func (cg *comparisonGenerator) tempLabel() string {
label := cg.generalStack.Push()
cg.generalStack.Pop()
return label
}
// extractByteWord separates byte and word operands
func (cg *comparisonGenerator) extractByteWord() (byte, word *operandInfo) {
if cg.kind1 == compiler.KindByte {
return cg.param1, cg.param2
}
return cg.param2, cg.param1
}
// == operator
func (cg *comparisonGenerator) genEqual() ([]string, error) {
// byte == byte
if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte {
return cg.genByteEqual()
}
// word == word
if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord {
return cg.genWordEqual()
}
// mixed
return cg.genMixedEqual()
}
func (cg *comparisonGenerator) genByteEqual() ([]string, error) {
if !cg.useLongJump {
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbne %s", cg.skipLabel),
}, nil
}
success := cg.tempLabel()
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbeq %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
func (cg *comparisonGenerator) genWordEqual() ([]string, error) {
if !cg.useLongJump {
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbne %s", cg.skipLabel),
cg.loadOperand(cg.param1, 1),
cg.cmpOperand(cg.param2, 1),
fmt.Sprintf("\tbne %s", cg.skipLabel),
}, nil
}
success := cg.tempLabel()
fail := cg.tempLabel()
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbne %s", fail),
cg.loadOperand(cg.param1, 1),
cg.cmpOperand(cg.param2, 1),
fmt.Sprintf("\tbeq %s", success),
fail,
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
func (cg *comparisonGenerator) genMixedEqual() ([]string, error) {
byteOp, wordOp := cg.extractByteWord()
if !cg.useLongJump {
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", cg.skipLabel),
cg.loadOperand(byteOp, 0),
cg.cmpOperand(wordOp, 0),
fmt.Sprintf("\tbne %s", cg.skipLabel),
}, nil
}
success := cg.tempLabel()
fail := cg.tempLabel()
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", fail),
cg.loadOperand(byteOp, 0),
cg.cmpOperand(wordOp, 0),
fmt.Sprintf("\tbeq %s", success),
fail,
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
// != operator
func (cg *comparisonGenerator) genNotEqual() ([]string, error) {
if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte {
return cg.genByteNotEqual()
}
if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord {
return cg.genWordNotEqual()
}
return cg.genMixedNotEqual()
}
func (cg *comparisonGenerator) genByteNotEqual() ([]string, error) {
if !cg.useLongJump {
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbeq %s", cg.skipLabel),
}, nil
}
success := cg.tempLabel()
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbne %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
func (cg *comparisonGenerator) genWordNotEqual() ([]string, error) {
success := cg.tempLabel()
if !cg.useLongJump {
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbne %s", success),
cg.loadOperand(cg.param1, 1),
cg.cmpOperand(cg.param2, 1),
fmt.Sprintf("\tbne %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbne %s", success),
cg.loadOperand(cg.param1, 1),
cg.cmpOperand(cg.param2, 1),
fmt.Sprintf("\tbne %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
func (cg *comparisonGenerator) genMixedNotEqual() ([]string, error) {
byteOp, wordOp := cg.extractByteWord()
success := cg.tempLabel()
if !cg.useLongJump {
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", success),
cg.loadOperand(byteOp, 0),
cg.cmpOperand(wordOp, 0),
fmt.Sprintf("\tbeq %s", cg.skipLabel),
success,
}, nil
}
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", success),
cg.loadOperand(byteOp, 0),
cg.cmpOperand(wordOp, 0),
fmt.Sprintf("\tbne %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
// > operator (unsigned)
func (cg *comparisonGenerator) genGreater() ([]string, error) {
if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte {
return cg.genByteGreater()
}
if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord {
return cg.genWordGreater()
}
return cg.genMixedGreater()
}
func (cg *comparisonGenerator) genByteGreater() ([]string, error) {
// p1 > p2: skip if p1 <= p2
// CMP sets C=1 if A >= operand
if !cg.useLongJump {
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbeq %s", cg.skipLabel), // equal means not greater
fmt.Sprintf("\tbcc %s", cg.skipLabel), // carry clear means less
}, nil
}
success := cg.tempLabel()
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbeq %s", cg.skipLabel),
fmt.Sprintf("\tbcs %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
func (cg *comparisonGenerator) genWordGreater() ([]string, error) {
// Compare high bytes first
success := cg.tempLabel()
if !cg.useLongJump {
return []string{
cg.loadOperand(cg.param1, 1),
cg.cmpOperand(cg.param2, 1),
fmt.Sprintf("\tbcc %s", cg.skipLabel), // p1_hi < p2_hi
fmt.Sprintf("\tbne %s", success), // p1_hi > p2_hi
// High bytes equal, check low
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbeq %s", cg.skipLabel), // equal not greater
fmt.Sprintf("\tbcc %s", cg.skipLabel), // p1_lo < p2_lo
success,
}, nil
}
return []string{
cg.loadOperand(cg.param1, 1),
cg.cmpOperand(cg.param2, 1),
fmt.Sprintf("\tbcc %s", cg.skipLabel),
fmt.Sprintf("\tbne %s", success),
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbeq %s", cg.skipLabel),
fmt.Sprintf("\tbcs %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
func (cg *comparisonGenerator) genMixedGreater() ([]string, error) {
byteOp, wordOp := cg.extractByteWord()
// If byte is param1 and word is param2: byte > word only if word_hi=0 and byte > word_lo
// If word is param1 and byte is param2: word > byte if word_hi>0 OR word_lo > byte
if cg.kind1 == compiler.KindByte {
// byte > word
success := cg.tempLabel()
if !cg.useLongJump {
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", cg.skipLabel), // word too large
cg.loadOperand(byteOp, 0),
cg.cmpOperand(wordOp, 0),
fmt.Sprintf("\tbeq %s", cg.skipLabel),
fmt.Sprintf("\tbcc %s", cg.skipLabel),
}, nil
}
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", cg.skipLabel),
cg.loadOperand(byteOp, 0),
cg.cmpOperand(wordOp, 0),
fmt.Sprintf("\tbeq %s", cg.skipLabel),
fmt.Sprintf("\tbcs %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
// word > byte
success := cg.tempLabel()
if !cg.useLongJump {
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", success), // word_hi > 0 means always greater
cg.loadOperand(wordOp, 0),
cg.cmpOperand(byteOp, 0),
fmt.Sprintf("\tbeq %s", cg.skipLabel),
fmt.Sprintf("\tbcc %s", cg.skipLabel),
success,
}, nil
}
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", success),
cg.loadOperand(wordOp, 0),
cg.cmpOperand(byteOp, 0),
fmt.Sprintf("\tbeq %s", cg.skipLabel),
fmt.Sprintf("\tbcs %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
// < operator
func (cg *comparisonGenerator) genLess() ([]string, error) {
// p1 < p2 is equivalent to p2 > p1
cg.param1, cg.param2 = cg.param2, cg.param1
cg.kind1, cg.kind2 = cg.kind2, cg.kind1
return cg.genGreater()
}
// >= operator
func (cg *comparisonGenerator) genGreaterEqual() ([]string, error) {
if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte {
return cg.genByteGreaterEqual()
}
if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord {
return cg.genWordGreaterEqual()
}
return cg.genMixedGreaterEqual()
}
func (cg *comparisonGenerator) genByteGreaterEqual() ([]string, error) {
// p1 >= p2: skip if p1 < p2
if !cg.useLongJump {
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbcc %s", cg.skipLabel),
}, nil
}
success := cg.tempLabel()
return []string{
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbcs %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
func (cg *comparisonGenerator) genWordGreaterEqual() ([]string, error) {
success := cg.tempLabel()
if !cg.useLongJump {
return []string{
cg.loadOperand(cg.param1, 1),
cg.cmpOperand(cg.param2, 1),
fmt.Sprintf("\tbcc %s", cg.skipLabel),
fmt.Sprintf("\tbne %s", success),
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbcc %s", cg.skipLabel),
success,
}, nil
}
return []string{
cg.loadOperand(cg.param1, 1),
cg.cmpOperand(cg.param2, 1),
fmt.Sprintf("\tbcc %s", cg.skipLabel),
fmt.Sprintf("\tbne %s", success),
cg.loadOperand(cg.param1, 0),
cg.cmpOperand(cg.param2, 0),
fmt.Sprintf("\tbcs %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
func (cg *comparisonGenerator) genMixedGreaterEqual() ([]string, error) {
byteOp, wordOp := cg.extractByteWord()
if cg.kind1 == compiler.KindByte {
// byte >= word
success := cg.tempLabel()
if !cg.useLongJump {
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", cg.skipLabel),
cg.loadOperand(byteOp, 0),
cg.cmpOperand(wordOp, 0),
fmt.Sprintf("\tbcc %s", cg.skipLabel),
}, nil
}
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", cg.skipLabel),
cg.loadOperand(byteOp, 0),
cg.cmpOperand(wordOp, 0),
fmt.Sprintf("\tbcs %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
// word >= byte
success := cg.tempLabel()
if !cg.useLongJump {
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", success),
cg.loadOperand(wordOp, 0),
cg.cmpOperand(byteOp, 0),
fmt.Sprintf("\tbcc %s", cg.skipLabel),
success,
}, nil
}
return []string{
cg.loadOperand(wordOp, 1),
"\tcmp #0",
fmt.Sprintf("\tbne %s", success),
cg.loadOperand(wordOp, 0),
cg.cmpOperand(byteOp, 0),
fmt.Sprintf("\tbcs %s", success),
fmt.Sprintf("\tjmp %s", cg.skipLabel),
success,
}, nil
}
// <= operator
func (cg *comparisonGenerator) genLessEqual() ([]string, error) {
// p1 <= p2 is equivalent to p2 >= p1
cg.param1, cg.param2 = cg.param2, cg.param1
cg.kind1, cg.kind2 = cg.kind2, cg.kind1
return cg.genGreaterEqual()
}
func inferKindFromValue(val uint16) compiler.VarKind {
if val <= 255 {
return compiler.KindByte
}
return compiler.KindWord
}

View file

@ -627,11 +627,3 @@ func (c *IfCommand) generateNotEqualLongJump(ctx *compiler.CompilerContext) ([]s
asm = append(asm, successLabel) asm = append(asm, successLabel)
return asm, nil return asm, nil
} }
// inferKindFromValue determines if a literal value is byte or word
func inferKindFromValue(val uint16) compiler.VarKind {
if val <= 255 {
return compiler.KindByte
}
return compiler.KindWord
}

59
internal/commands/wend.go Normal file
View file

@ -0,0 +1,59 @@
package commands
import (
"fmt"
"strings"
"c65gm/internal/compiler"
"c65gm/internal/preproc"
"c65gm/internal/utils"
)
// WendCommand handles WEND statements
// Syntax: WEND
// Ends current WHILE loop
type WendCommand struct {
loopLabel string
skipLabel string
}
func (c *WendCommand) WillHandle(line preproc.Line) bool {
params, err := utils.ParseParams(line.Text)
if err != nil || len(params) == 0 {
return false
}
return strings.ToUpper(params[0]) == "WEND"
}
func (c *WendCommand) Interpret(line preproc.Line, ctx *compiler.CompilerContext) error {
params, err := utils.ParseParams(line.Text)
if err != nil {
return err
}
if len(params) != 1 {
return fmt.Errorf("WEND: expected 1 parameter, got %d", len(params))
}
// Pop loop label
var err2 error
c.loopLabel, err2 = ctx.LoopStack.Pop()
if err2 != nil {
return fmt.Errorf("WEND: not inside WHILE loop")
}
// Pop skip label
c.skipLabel, err2 = ctx.WhileStack.Pop()
if err2 != nil {
return fmt.Errorf("WEND: not inside WHILE loop")
}
return nil
}
func (c *WendCommand) Generate(_ *compiler.CompilerContext) ([]string, error) {
return []string{
fmt.Sprintf("\tjmp %s", c.loopLabel),
c.skipLabel,
}, nil
}

152
internal/commands/while.go Normal file
View file

@ -0,0 +1,152 @@
package commands
import (
"fmt"
"strings"
"c65gm/internal/compiler"
"c65gm/internal/preproc"
"c65gm/internal/utils"
)
// WhileCommand handles WHILE loop statements
// Syntax: WHILE <param1> <op> <param2>
// Operators: =, ==, <>, !=, >, <, >=, <=
type WhileCommand struct {
operator string
param1 *operandInfo
param2 *operandInfo
useLongJump bool
loopLabel string
skipLabel string
}
func (c *WhileCommand) WillHandle(line preproc.Line) bool {
params, err := utils.ParseParams(line.Text)
if err != nil || len(params) == 0 {
return false
}
return strings.ToUpper(params[0]) == "WHILE"
}
func (c *WhileCommand) Interpret(line preproc.Line, ctx *compiler.CompilerContext) error {
params, err := utils.ParseParams(line.Text)
if err != nil {
return err
}
if len(params) != 4 {
return fmt.Errorf("WHILE: expected 4 parameters, got %d", len(params))
}
c.operator = normalizeOperator(params[2])
scope := ctx.CurrentScope()
constLookup := func(name string) (int64, bool) {
sym := ctx.SymbolTable.Lookup(name, scope)
if sym != nil && sym.IsConst() {
return int64(sym.Value), true
}
return 0, false
}
// Parse param1
varName, varKind, value, isVar, err := compiler.ParseOperandParam(
params[1], ctx.SymbolTable, scope, constLookup)
if err != nil {
return fmt.Errorf("WHILE: param1: %w", err)
}
c.param1 = &operandInfo{
varName: varName,
varKind: varKind,
value: value,
isVar: isVar,
}
// Parse param2
varName, varKind, value, isVar, err = compiler.ParseOperandParam(
params[3], ctx.SymbolTable, scope, constLookup)
if err != nil {
return fmt.Errorf("WHILE: param2: %w", err)
}
c.param2 = &operandInfo{
varName: varName,
varKind: varKind,
value: value,
isVar: isVar,
}
// Check pragma
ps := ctx.Pragma.GetPragmaSetByIndex(line.PragmaSetIndex)
longJumpPragma := ps.GetPragma("_P_USE_LONG_JUMP")
c.useLongJump = longJumpPragma != "" && longJumpPragma != "0"
// Create labels
c.loopLabel = ctx.LoopStack.Push()
c.skipLabel = ctx.WhileStack.Push()
return nil
}
func (c *WhileCommand) Generate(ctx *compiler.CompilerContext) ([]string, error) {
op, err := parseOperator(c.operator)
if err != nil {
return nil, fmt.Errorf("WHILE: %w", err)
}
// Emit loop label
asm := []string{c.loopLabel}
// Generate comparison (jumps to skipLabel on false)
gen, err := newComparisonGenerator(
op,
c.param1,
c.param2,
c.useLongJump,
ctx.WhileStack,
ctx.GeneralStack,
)
if err != nil {
return nil, fmt.Errorf("WHILE: %w", err)
}
cmpAsm, err := gen.generate()
if err != nil {
return nil, fmt.Errorf("WHILE: %w", err)
}
asm = append(asm, cmpAsm...)
return asm, nil
}
// normalizeOperator converts operator variants to canonical form
func normalizeOperator(op string) string {
switch op {
case "==":
return "="
case "!=":
return "<>"
default:
return op
}
}
// parseOperator converts string operator to comparisonOp
func parseOperator(op string) (comparisonOp, error) {
switch op {
case "=":
return opEqual, nil
case "<>":
return opNotEqual, nil
case ">":
return opGreater, nil
case "<":
return opLess, nil
case ">=":
return opGreaterEqual, nil
case "<=":
return opLessEqual, nil
default:
return 0, fmt.Errorf("unsupported operator %q", op)
}
}

View file

@ -0,0 +1,419 @@
package commands
import (
"strings"
"testing"
"c65gm/internal/compiler"
"c65gm/internal/preproc"
)
func TestWhileBasicEqual(t *testing.T) {
tests := []struct {
name string
whileLine string
setupVars func(*compiler.SymbolTable)
wantWhile []string
wantWend []string
}{
{
name: "byte var == byte literal",
whileLine: "WHILE x = 10",
setupVars: func(st *compiler.SymbolTable) {
st.AddVar("x", "", compiler.KindByte, 0)
},
wantWhile: []string{
"_LOOP1",
"\tlda x",
"\tcmp #$0a",
"\tbne _WEND1",
},
wantWend: []string{
"\tjmp _LOOP1",
"_WEND1",
},
},
{
name: "word var == word literal",
whileLine: "WHILE x = 1000",
setupVars: func(st *compiler.SymbolTable) {
st.AddVar("x", "", compiler.KindWord, 0)
},
wantWhile: []string{
"_LOOP1",
"\tlda x",
"\tcmp #$e8",
"\tbne _WEND1",
"\tlda x+1",
"\tcmp #$03",
"\tbne _WEND1",
},
wantWend: []string{
"\tjmp _LOOP1",
"_WEND1",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pragma := preproc.NewPragma()
ctx := compiler.NewCompilerContext(pragma)
tt.setupVars(ctx.SymbolTable)
whileCmd := &WhileCommand{}
wendCmd := &WendCommand{}
whileLine := preproc.Line{
Text: tt.whileLine,
Kind: preproc.Source,
PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
}
wendLine := preproc.Line{
Text: "WEND",
Kind: preproc.Source,
PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
}
if err := whileCmd.Interpret(whileLine, ctx); err != nil {
t.Fatalf("WHILE Interpret() error = %v", err)
}
whileAsm, err := whileCmd.Generate(ctx)
if err != nil {
t.Fatalf("WHILE Generate() error = %v", err)
}
if err := wendCmd.Interpret(wendLine, ctx); err != nil {
t.Fatalf("WEND Interpret() error = %v", err)
}
wendAsm, err := wendCmd.Generate(ctx)
if err != nil {
t.Fatalf("WEND Generate() error = %v", err)
}
if !equalAsm(whileAsm, tt.wantWhile) {
t.Errorf("WHILE Generate() mismatch\ngot:\n%s\nwant:\n%s",
strings.Join(whileAsm, "\n"),
strings.Join(tt.wantWhile, "\n"))
}
if !equalAsm(wendAsm, tt.wantWend) {
t.Errorf("WEND Generate() mismatch\ngot:\n%s\nwant:\n%s",
strings.Join(wendAsm, "\n"),
strings.Join(tt.wantWend, "\n"))
}
})
}
}
func TestWhileAllOperators(t *testing.T) {
tests := []struct {
name string
line string
wantInst string
}{
{"equal", "WHILE x = 10", "bne"},
{"not equal", "WHILE x <> 10", "beq"},
{"greater", "WHILE x > 10", "beq"},
{"less", "WHILE x < 10", "beq"},
{"greater equal", "WHILE x >= 10", "bcc"},
{"less equal", "WHILE x <= 10", "bcc"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pragma := preproc.NewPragma()
ctx := compiler.NewCompilerContext(pragma)
ctx.SymbolTable.AddVar("x", "", compiler.KindByte, 0)
cmd := &WhileCommand{}
line := preproc.Line{
Text: tt.line,
Kind: preproc.Source,
PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
}
if err := cmd.Interpret(line, ctx); err != nil {
t.Fatalf("Interpret() error = %v", err)
}
asm, err := cmd.Generate(ctx)
if err != nil {
t.Fatalf("Generate() error = %v", err)
}
found := false
for _, inst := range asm {
if strings.Contains(inst, tt.wantInst) {
found = true
break
}
}
if !found {
t.Errorf("Expected %s instruction not found in: %v", tt.wantInst, asm)
}
})
}
}
func TestWhileMixedTypes(t *testing.T) {
pragma := preproc.NewPragma()
ctx := compiler.NewCompilerContext(pragma)
ctx.SymbolTable.AddVar("x", "", compiler.KindByte, 0)
ctx.SymbolTable.AddVar("y", "", compiler.KindWord, 0)
cmd := &WhileCommand{}
line := preproc.Line{
Text: "WHILE x < y",
Kind: preproc.Source,
PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
}
if err := cmd.Interpret(line, ctx); err != nil {
t.Fatalf("Interpret() error = %v", err)
}
asm, err := cmd.Generate(ctx)
if err != nil {
t.Fatalf("Generate() error = %v", err)
}
foundHighByteCheck := false
for _, inst := range asm {
if strings.Contains(inst, "y+1") {
foundHighByteCheck = true
break
}
}
if !foundHighByteCheck {
t.Error("Expected high byte check for word in mixed comparison")
}
}
func TestWhileBreak(t *testing.T) {
pragma := preproc.NewPragma()
ctx := compiler.NewCompilerContext(pragma)
ctx.SymbolTable.AddVar("x", "", compiler.KindByte, 0)
whileCmd := &WhileCommand{}
breakCmd := &BreakCommand{}
wendCmd := &WendCommand{}
pragmaIdx := pragma.GetCurrentPragmaSetIndex()
whileLine := preproc.Line{Text: "WHILE x < 10", Kind: preproc.Source, PragmaSetIndex: pragmaIdx}
breakLine := preproc.Line{Text: "BREAK", Kind: preproc.Source, PragmaSetIndex: pragmaIdx}
wendLine := preproc.Line{Text: "WEND", Kind: preproc.Source, PragmaSetIndex: pragmaIdx}
if err := whileCmd.Interpret(whileLine, ctx); err != nil {
t.Fatalf("WHILE Interpret() error = %v", err)
}
whileAsm, _ := whileCmd.Generate(ctx)
_ = whileAsm // body would go here
if err := breakCmd.Interpret(breakLine, ctx); err != nil {
t.Fatalf("BREAK Interpret() error = %v", err)
}
breakAsm, err := breakCmd.Generate(ctx)
if err != nil {
t.Fatalf("BREAK Generate() error = %v", err)
}
if err := wendCmd.Interpret(wendLine, ctx); err != nil {
t.Fatalf("WEND Interpret() error = %v", err)
}
if len(breakAsm) != 1 || !strings.Contains(breakAsm[0], "jmp _WEND") {
t.Errorf("BREAK should jump to WEND label, got: %v", breakAsm)
}
}
func TestBreakOutsideLoop(t *testing.T) {
pragma := preproc.NewPragma()
ctx := compiler.NewCompilerContext(pragma)
cmd := &BreakCommand{}
line := preproc.Line{
Text: "BREAK",
Kind: preproc.Source,
PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
}
err := cmd.Interpret(line, ctx)
if err == nil {
t.Fatal("BREAK outside loop should fail")
}
if !strings.Contains(err.Error(), "not inside WHILE") {
t.Errorf("Wrong error message: %v", err)
}
}
func TestWendWithoutWhile(t *testing.T) {
pragma := preproc.NewPragma()
ctx := compiler.NewCompilerContext(pragma)
cmd := &WendCommand{}
line := preproc.Line{
Text: "WEND",
Kind: preproc.Source,
PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
}
err := cmd.Interpret(line, ctx)
if err == nil {
t.Fatal("WEND without WHILE should fail")
}
if !strings.Contains(err.Error(), "not inside WHILE") {
t.Errorf("Wrong error message: %v", err)
}
}
func TestWhileNested(t *testing.T) {
pragma := preproc.NewPragma()
ctx := compiler.NewCompilerContext(pragma)
ctx.SymbolTable.AddVar("i", "", compiler.KindByte, 0)
ctx.SymbolTable.AddVar("j", "", compiler.KindByte, 0)
pragmaIdx := pragma.GetCurrentPragmaSetIndex()
while1 := &WhileCommand{}
while2 := &WhileCommand{}
wend1 := &WendCommand{}
wend2 := &WendCommand{}
if err := while1.Interpret(preproc.Line{Text: "WHILE i < 10", Kind: preproc.Source, PragmaSetIndex: pragmaIdx}, ctx); err != nil {
t.Fatalf("WHILE 1 error = %v", err)
}
asm1, err := while1.Generate(ctx)
if err != nil {
t.Fatalf("WHILE 1 Generate error = %v", err)
}
if err := while2.Interpret(preproc.Line{Text: "WHILE j < 5", Kind: preproc.Source, PragmaSetIndex: pragmaIdx}, ctx); err != nil {
t.Fatalf("WHILE 2 error = %v", err)
}
asm2, err := while2.Generate(ctx)
if err != nil {
t.Fatalf("WHILE 2 Generate error = %v", err)
}
if err := wend2.Interpret(preproc.Line{Text: "WEND", Kind: preproc.Source, PragmaSetIndex: pragmaIdx}, ctx); err != nil {
t.Fatalf("WEND 2 error = %v", err)
}
if err := wend1.Interpret(preproc.Line{Text: "WEND", Kind: preproc.Source, PragmaSetIndex: pragmaIdx}, ctx); err != nil {
t.Fatalf("WEND 1 error = %v", err)
}
if asm1[0] == asm2[0] {
t.Error("Nested loops should have different labels")
}
}
func TestWhileLongJump(t *testing.T) {
pragma := preproc.NewPragma()
pragma.AddPragma("_P_USE_LONG_JUMP", "1")
ctx := compiler.NewCompilerContext(pragma)
ctx.SymbolTable.AddVar("x", "", compiler.KindByte, 0)
cmd := &WhileCommand{}
line := preproc.Line{
Text: "WHILE x = 10",
Kind: preproc.Source,
PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
}
if err := cmd.Interpret(line, ctx); err != nil {
t.Fatalf("Interpret() error = %v", err)
}
asm, err := cmd.Generate(ctx)
if err != nil {
t.Fatalf("Generate() error = %v", err)
}
foundJmp := false
for _, inst := range asm {
if strings.Contains(inst, "jmp") {
foundJmp = true
break
}
}
if !foundJmp {
t.Error("Long jump mode should contain JMP instruction")
}
}
func TestWhileConstant(t *testing.T) {
pragma := preproc.NewPragma()
ctx := compiler.NewCompilerContext(pragma)
ctx.SymbolTable.AddConst("MAX", "", compiler.KindByte, 100)
ctx.SymbolTable.AddVar("x", "", compiler.KindByte, 0)
cmd := &WhileCommand{}
line := preproc.Line{
Text: "WHILE x < MAX",
Kind: preproc.Source,
PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
}
if err := cmd.Interpret(line, ctx); err != nil {
t.Fatalf("Interpret() error = %v", err)
}
asm, err := cmd.Generate(ctx)
if err != nil {
t.Fatalf("Generate() error = %v", err)
}
found := false
for _, inst := range asm {
if strings.Contains(inst, "#$64") {
found = true
break
}
}
if !found {
t.Error("Constant should be folded to immediate value")
}
}
func TestWhileWrongParamCount(t *testing.T) {
pragma := preproc.NewPragma()
ctx := compiler.NewCompilerContext(pragma)
tests := []string{
"WHILE x",
"WHILE x =",
"WHILE x = 10 extra",
}
for _, text := range tests {
cmd := &WhileCommand{}
line := preproc.Line{
Text: text,
Kind: preproc.Source,
PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
}
err := cmd.Interpret(line, ctx)
if err == nil {
t.Errorf("Should fail with wrong param count: %s", text)
}
}
}
// Helper to compare assembly output
func equalAsm(got, want []string) bool {
if len(got) != len(want) {
return false
}
for i := range got {
if got[i] != want[i] {
return false
}
}
return true
}

View file

@ -16,6 +16,7 @@ type CompilerContext struct {
ConstStrHandler *ConstantStringHandler ConstStrHandler *ConstantStringHandler
// Label stacks for control flow // Label stacks for control flow
LoopStack *LabelStack // Start of loop (like WHILE)
WhileStack *LabelStack // WHILE...WEND WhileStack *LabelStack // WHILE...WEND
IfStack *LabelStack // IF...ENDIF IfStack *LabelStack // IF...ENDIF
GeneralStack *LabelStack // General purpose (GOSUB, etc) GeneralStack *LabelStack // General purpose (GOSUB, etc)
@ -33,7 +34,8 @@ func NewCompilerContext(pragma *preproc.Pragma) *CompilerContext {
ctx := &CompilerContext{ ctx := &CompilerContext{
SymbolTable: symTable, SymbolTable: symTable,
ConstStrHandler: constStrHandler, ConstStrHandler: constStrHandler,
WhileStack: NewLabelStack("_W"), LoopStack: NewLabelStack("_LOOP"),
WhileStack: NewLabelStack("_WEND"),
IfStack: NewLabelStack("_I"), IfStack: NewLabelStack("_I"),
GeneralStack: generalStack, GeneralStack: generalStack,
Pragma: pragma, Pragma: pragma,

View file

@ -87,6 +87,10 @@ func registerCommands(comp *compiler.Compiler) {
comp.Registry().Register(&commands.IfCommand{}) comp.Registry().Register(&commands.IfCommand{})
comp.Registry().Register(&commands.ElseCommand{}) comp.Registry().Register(&commands.ElseCommand{})
comp.Registry().Register(&commands.EndIfCommand{}) comp.Registry().Register(&commands.EndIfCommand{})
comp.Registry().Register(&commands.WhileCommand{})
comp.Registry().Register(&commands.BreakCommand{})
comp.Registry().Register(&commands.WendCommand{})
} }
func writeOutput(filename string, lines []string) error { func writeOutput(filename string, lines []string) error {