c65gm/internal/commands/comparison.go

659 lines
16 KiB
Go

package commands
import (
"fmt"
"c65gm/internal/compiler"
)
// comparisonOp identifies the relational operator of a comparison.
// The numeric values are written out explicitly (in the same order an iota
// sequence would assign them) so that the %d in diagnostics such as
// "unsupported operator %d" maps directly back to an operator.
type comparisonOp int

const (
	opEqual        comparisonOp = 0 // "=" or "=="
	opNotEqual     comparisonOp = 1 // "<>" or "!="
	opGreater      comparisonOp = 2 // ">"
	opLess         comparisonOp = 3 // "<"
	opGreaterEqual comparisonOp = 4 // ">="
	opLessEqual    comparisonOp = 5 // "<="
)
// operandInfo describes one side of a comparison. It is either a variable
// (isVar == true), referenced by varName with declared width varKind, or a
// literal constant (isVar == false) held in value.
type operandInfo struct {
varName string // assembler symbol of the variable; only used when isVar is true
varKind compiler.VarKind // declared width of the variable; getKind infers literal widths from value instead
value uint16 // literal value; only used when isVar is false
isVar bool // true for a variable operand, false for a literal
}
// comparisonGenerator emits 6502 assembly for one comparison used by a
// control-flow construct. The emitted code falls through when the comparison
// is true and jumps to skipLabel when it is false.
type comparisonGenerator struct {
	operator comparisonOp // relational operator to evaluate

	param1, param2 *operandInfo // left- and right-hand operands

	// useLongJump asks for sequences that reach skipLabel through a JMP
	// instead of a conditional branch. 6502 relative branches only span
	// -128..+127 bytes, so this is presumably used for distant skip targets.
	useLongJump bool

	compFlowStack *compiler.LabelStack // its top entry is read as the skip label by generate
	generalStack  *compiler.LabelStack // source of temporary labels (see tempLabel)

	// Populated by generate before code is emitted.
	skipLabel    string
	kind1, kind2 compiler.VarKind // effective widths of param1 and param2
}
// newComparisonGenerator builds a generator for "param1 operator param2".
//
// param1 is required and an error is returned when it is nil. When param2 is
// nil, the comparison is treated as a truth test and becomes "param1 != 0",
// whatever operator was passed. compFlowStack supplies the skip label and
// generalStack supplies the temporary labels used during generation.
func newComparisonGenerator(
	operator comparisonOp,
	param1 *operandInfo,
	param2 *operandInfo,
	useLongJump bool,
	compFlowStack *compiler.LabelStack,
	generalStack *compiler.LabelStack,
) (*comparisonGenerator, error) {
	if param1 == nil {
		return nil, fmt.Errorf("param1 cannot be nil")
	}

	cg := &comparisonGenerator{
		operator:      operator,
		param1:        param1,
		param2:        param2,
		useLongJump:   useLongJump,
		compFlowStack: compFlowStack,
		generalStack:  generalStack,
	}
	if cg.param2 == nil {
		// Single-operand form: compare against the literal zero.
		cg.operator = opNotEqual
		cg.param2 = &operandInfo{value: 0, isVar: false}
	}
	return cg, nil
}
// generate emits the assembly for the comparison. The emitted code falls
// through when the comparison holds and jumps to the skip label, which is the
// current top of compFlowStack, when it does not.
//
// Before dispatching, it stores the skip label and both operands' effective
// kinds on cg. A comparison between two literals is folded at compile time.
func (cg *comparisonGenerator) generate() ([]string, error) {
	label, err := cg.compFlowStack.Peek()
	cg.skipLabel = label
	if err != nil {
		return nil, fmt.Errorf("comparison: %w", err)
	}

	cg.kind1, cg.kind2 = cg.getKind(cg.param1), cg.getKind(cg.param2)

	if !(cg.param1.isVar || cg.param2.isVar) {
		// Both operands are literals, so the result is already known.
		return cg.constantFold()
	}

	var emit func() ([]string, error)
	switch cg.operator {
	case opEqual:
		emit = cg.genEqual
	case opNotEqual:
		emit = cg.genNotEqual
	case opGreater:
		emit = cg.genGreater
	case opLess:
		emit = cg.genLess
	case opGreaterEqual:
		emit = cg.genGreaterEqual
	case opLessEqual:
		emit = cg.genLessEqual
	default:
		return nil, fmt.Errorf("unsupported operator %d", cg.operator)
	}
	return emit()
}
// getKind reports the width used for op during code generation. For a
// variable this is its declared kind. For a literal it is the narrowest kind
// that holds the value (see inferKindFromValue).
func (cg *comparisonGenerator) getKind(op *operandInfo) compiler.VarKind {
	if !op.isVar {
		return inferKindFromValue(op.value)
	}
	return op.varKind
}
// constantFold resolves a comparison whose operands are both literals.
// Values are compared as unsigned uint16. When the comparison is false, a
// single unconditional jump to the skip label is emitted; when it is true, no
// code is emitted (an empty, non-nil slice is returned).
func (cg *comparisonGenerator) constantFold() ([]string, error) {
	lhs := cg.param1.value
	rhs := cg.param2.value

	var holds bool
	switch cg.operator {
	case opEqual:
		holds = lhs == rhs
	case opNotEqual:
		holds = lhs != rhs
	case opGreater:
		holds = lhs > rhs
	case opLess:
		holds = lhs < rhs
	case opGreaterEqual:
		holds = lhs >= rhs
	case opLessEqual:
		holds = lhs <= rhs
	default:
		return nil, fmt.Errorf("operator not supported for constant folding")
	}

	if holds {
		return []string{}, nil
	}
	return []string{"\tjmp " + cg.skipLabel}, nil
}
// loadOperand returns the LDA instruction that loads one byte of op into the
// accumulator. A variable is addressed as "name" (offset 0) or "name+offset".
// A literal becomes an immediate operand: its low byte for offset 0, and its
// high byte (value >> 8) for offset 1.
func (cg *comparisonGenerator) loadOperand(op *operandInfo, offset int) string {
	var operand string
	switch {
	case !op.isVar:
		lit := op.value
		if offset == 1 {
			lit >>= 8
		}
		operand = fmt.Sprintf("#$%02x", uint8(lit))
	case offset > 0:
		operand = fmt.Sprintf("%s+%d", op.varName, offset)
	default:
		operand = op.varName
	}
	return "\tlda " + operand
}
// cmpOperand returns the CMP instruction that compares the accumulator with
// one byte of op. A variable is addressed as "name" (offset 0) or
// "name+offset". A literal becomes an immediate operand: its low byte for
// offset 0, and its high byte (value >> 8) for offset 1.
func (cg *comparisonGenerator) cmpOperand(op *operandInfo, offset int) string {
	var operand string
	switch {
	case !op.isVar:
		lit := op.value
		if offset == 1 {
			lit >>= 8
		}
		operand = fmt.Sprintf("#$%02x", uint8(lit))
	case offset > 0:
		operand = fmt.Sprintf("%s+%d", op.varName, offset)
	default:
		operand = op.varName
	}
	return "\tcmp " + operand
}
// tempLabel returns a label name for a branch target local to one comparison.
// The name comes from a Push on generalStack followed immediately by a Pop,
// so the stack depth is unchanged afterwards. Presumably Push is what makes
// the name unique; confirm in compiler.LabelStack.
func (cg *comparisonGenerator) tempLabel() (label string) {
	label = cg.generalStack.Push()
	// The Pop only undoes the Push above; its result and error are not needed.
	_, _ = cg.generalStack.Pop()
	return label
}
// isZeroLiteral reports whether op is the literal constant 0. A variable never
// qualifies, whatever its runtime value.
func (cg *comparisonGenerator) isZeroLiteral(op *operandInfo) bool {
	if op.isVar {
		return false
	}
	return op.value == 0
}
// genLoadAndTest loads byte `offset` of loadOp into A and compares it with the
// same byte of cmpOp, so Z is set when the two bytes are equal.
//
// When cmpOp is the literal 0, the CMP is omitted because LDA already sets Z
// from the loaded value. In that case only Z is meaningful (LDA does not touch
// carry), so the result must be consumed with BEQ/BNE.
func (cg *comparisonGenerator) genLoadAndTest(loadOp, cmpOp *operandInfo, offset int) []string {
	lines := make([]string, 0, 2)
	lines = append(lines, cg.loadOperand(loadOp, offset))
	if cg.isZeroLiteral(cmpOp) {
		return lines
	}
	return append(lines, cg.cmpOperand(cmpOp, offset))
}
// extractByteWord returns the operands of a mixed-width comparison in the
// order (byte-sized operand, word-sized operand), decided by cg.kind1.
//
// It assumes exactly one operand is byte-sized. If param1 is not KindByte,
// param2 is returned as the byte operand without further checks.
//
// The results are named byteOp/wordOp rather than byte/word so they do not
// shadow Go's predeclared `byte` type inside the function.
func (cg *comparisonGenerator) extractByteWord() (byteOp, wordOp *operandInfo) {
	if cg.kind1 == compiler.KindByte {
		return cg.param1, cg.param2
	}
	return cg.param2, cg.param1
}
// genEqual emits code for "param1 == param2". It picks the byte, word or
// mixed-width sequence from the operands' effective kinds.
func (cg *comparisonGenerator) genEqual() ([]string, error) {
	switch {
	case cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte:
		return cg.genByteEqual()
	case cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord:
		return cg.genWordEqual()
	default:
		return cg.genMixedEqual()
	}
}
// genByteEqual emits "param1 == param2" for two byte operands.
// Short form: BNE straight to the skip label.
// Long form: BEQ over a JMP to the skip label.
func (cg *comparisonGenerator) genByteEqual() ([]string, error) {
	test := cg.genLoadAndTest(cg.param1, cg.param2, 0)
	if cg.useLongJump {
		over := cg.tempLabel()
		return append(test,
			"\tbeq "+over,
			"\tjmp "+cg.skipLabel,
			over,
		), nil
	}
	return append(test, "\tbne "+cg.skipLabel), nil
}
// genWordEqual emits "param1 == param2" for two word operands. The low bytes
// are compared, then the high bytes; any mismatch means the comparison is false.
func (cg *comparisonGenerator) genWordEqual() ([]string, error) {
	lo := cg.genLoadAndTest(cg.param1, cg.param2, 0)
	hi := cg.genLoadAndTest(cg.param1, cg.param2, 1)

	if !cg.useLongJump {
		out := append(lo, "\tbne "+cg.skipLabel)
		out = append(out, hi...)
		return append(out, "\tbne "+cg.skipLabel), nil
	}

	// Long form: mismatches land on a local JMP instead of branching to skipLabel.
	pass := cg.tempLabel()
	miss := cg.tempLabel()
	out := append(lo, "\tbne "+miss)
	out = append(out, hi...)
	return append(out,
		"\tbeq "+pass,
		miss,
		"\tjmp "+cg.skipLabel,
		pass,
	), nil
}
// genMixedEqual emits "param1 == param2" when one operand is byte-sized and
// the other word-sized. A non-zero high byte in the word operand makes
// equality impossible; otherwise the low bytes decide.
//
// In long-jump mode, every failing path goes through the local `fail` label
// and its JMP, so no conditional branch targets the skip label directly.
func (cg *comparisonGenerator) genMixedEqual() ([]string, error) {
	byteOp, wordOp := cg.extractByteWord()

	if !cg.useLongJump {
		code := []string{
			cg.loadOperand(wordOp, 1),
			fmt.Sprintf("\tbne %s", cg.skipLabel), // word high byte != 0: cannot be equal
		}
		code = append(code, cg.genLoadAndTest(wordOp, byteOp, 0)...)
		code = append(code, fmt.Sprintf("\tbne %s", cg.skipLabel))
		return code, nil
	}

	success := cg.tempLabel()
	fail := cg.tempLabel()
	code := []string{
		cg.loadOperand(wordOp, 1),
		// Word high byte != 0: cannot be equal. Branch to the local fail label
		// and its JMP instead of branching to skipLabel directly.
		fmt.Sprintf("\tbne %s", fail),
	}
	code = append(code, cg.genLoadAndTest(wordOp, byteOp, 0)...)
	code = append(code,
		fmt.Sprintf("\tbeq %s", success),
		fail,
		fmt.Sprintf("\tjmp %s", cg.skipLabel),
		success,
	)
	return code, nil
}
// genNotEqual emits code for "param1 != param2". It picks the byte, word or
// mixed-width sequence from the operands' effective kinds.
func (cg *comparisonGenerator) genNotEqual() ([]string, error) {
	switch {
	case cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte:
		return cg.genByteNotEqual()
	case cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord:
		return cg.genWordNotEqual()
	default:
		return cg.genMixedNotEqual()
	}
}
// genByteNotEqual emits "param1 != param2" for two byte operands.
// Short form: BEQ straight to the skip label.
// Long form: BNE over a JMP to the skip label.
func (cg *comparisonGenerator) genByteNotEqual() ([]string, error) {
	test := cg.genLoadAndTest(cg.param1, cg.param2, 0)
	if cg.useLongJump {
		over := cg.tempLabel()
		return append(test, "\tbne "+over, "\tjmp "+cg.skipLabel, over), nil
	}
	return append(test, "\tbeq "+cg.skipLabel), nil
}
// genWordNotEqual emits "param1 != param2" for two word operands.
// The values differ as soon as one byte pair differs, so a low-byte mismatch
// branches straight to the success label and the high bytes decide the rest.
//
// Short form: ends with a single BEQ to the skip label.
// Long form: cannot branch to the skip label directly, so it falls into a
// local JMP instead.
func (cg *comparisonGenerator) genWordNotEqual() ([]string, error) {
	success := cg.tempLabel()
	code := cg.genLoadAndTest(cg.param1, cg.param2, 0)
	code = append(code, fmt.Sprintf("\tbne %s", success)) // low bytes differ: values differ
	code = append(code, cg.genLoadAndTest(cg.param1, cg.param2, 1)...)

	if !cg.useLongJump {
		code = append(code,
			fmt.Sprintf("\tbeq %s", cg.skipLabel), // all bytes equal: condition is false
			success,
		)
		return code, nil
	}

	code = append(code,
		fmt.Sprintf("\tbne %s", success),
		fmt.Sprintf("\tjmp %s", cg.skipLabel),
		success,
	)
	return code, nil
}
// genMixedNotEqual emits "param1 != param2" when one operand is byte-sized and
// the other word-sized. A non-zero high byte in the word operand already proves
// inequality; otherwise the low bytes decide.
func (cg *comparisonGenerator) genMixedNotEqual() ([]string, error) {
	byteOp, wordOp := cg.extractByteWord()
	done := cg.tempLabel()

	out := append(
		[]string{cg.loadOperand(wordOp, 1), "\tbne " + done},
		cg.genLoadAndTest(wordOp, byteOp, 0)...,
	)

	var tail []string
	if cg.useLongJump {
		tail = []string{"\tbne " + done, "\tjmp " + cg.skipLabel, done}
	} else {
		tail = []string{"\tbeq " + cg.skipLabel, done}
	}
	return append(out, tail...), nil
}
// genGreater emits code for the unsigned comparison "param1 > param2". It
// picks the byte, word or mixed-width sequence from the operands' effective
// kinds.
func (cg *comparisonGenerator) genGreater() ([]string, error) {
	switch {
	case cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte:
		return cg.genByteGreater()
	case cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord:
		return cg.genWordGreater()
	default:
		return cg.genMixedGreater()
	}
}
// genByteGreater emits "param1 > param2" (unsigned) for two byte operands.
// The operands are swapped so a single carry test suffices: after
// "LDA param2 / CMP param1", C is set exactly when param2 >= param1, i.e.
// when "param1 > param2" is false.
//
// The long form needs only one local label: the success target that skips
// over the JMP.
func (cg *comparisonGenerator) genByteGreater() ([]string, error) {
	code := []string{
		cg.loadOperand(cg.param2, 0),
		cg.cmpOperand(cg.param1, 0),
	}
	if !cg.useLongJump {
		// Skip if param2 >= param1, i.e. param1 <= param2.
		return append(code, fmt.Sprintf("\tbcs %s", cg.skipLabel)), nil
	}

	success := cg.tempLabel()
	return append(code,
		fmt.Sprintf("\tbcc %s", success), // param2 < param1: condition holds
		fmt.Sprintf("\tjmp %s", cg.skipLabel),
		success,
	), nil
}
// genWordGreater emits "param1 > param2" (unsigned) for two word operands.
// The high bytes are compared first. If they are equal, the low bytes are
// compared with the operands swapped so that one carry test decides.
func (cg *comparisonGenerator) genWordGreater() ([]string, error) {
	won := cg.tempLabel()
	hiCompare := []string{cg.loadOperand(cg.param1, 1), cg.cmpOperand(cg.param2, 1)}
	// Swapped: C set after this means param2_lo >= param1_lo, i.e. not greater.
	loCompare := []string{cg.loadOperand(cg.param2, 0), cg.cmpOperand(cg.param1, 0)}

	var out []string
	if cg.useLongJump {
		lost := cg.tempLabel()
		out = append(out, hiCompare...)
		out = append(out, "\tbcc "+lost, "\tbne "+won) // hi smaller: false; hi larger: true
		out = append(out, loCompare...)
		out = append(out, "\tbcc "+won, lost, "\tjmp "+cg.skipLabel, won)
		return out, nil
	}

	out = append(out, hiCompare...)
	out = append(out, "\tbcc "+cg.skipLabel, "\tbne "+won) // hi smaller: false; hi larger: true
	out = append(out, loCompare...)
	out = append(out, "\tbcs "+cg.skipLabel, won)
	return out, nil
}
// genMixedGreater emits "param1 > param2" (unsigned) when one operand is
// byte-sized and the other word-sized.
//
//   - byte > word holds only if the word's high byte is 0 and byte > word_lo.
//   - word > byte holds if the word's high byte is non-zero, or word_lo > byte.
//
// The low bytes are compared with the left-hand operand in A, so BEQ/BCC
// detect "not greater".
func (cg *comparisonGenerator) genMixedGreater() ([]string, error) {
	byteOp, wordOp := cg.extractByteWord()
	byteOnLeft := cg.kind1 == compiler.KindByte

	// The success label is always allocated first (it goes unused in the short
	// byte > word sequence). The long form also allocates a fail label.
	success := cg.tempLabel()
	fail := ""
	if cg.useLongJump {
		fail = cg.tempLabel()
	}

	var lowCompare []string
	if byteOnLeft {
		lowCompare = []string{cg.loadOperand(byteOp, 0), cg.cmpOperand(wordOp, 0)}
	} else {
		lowCompare = []string{cg.loadOperand(wordOp, 0), cg.cmpOperand(byteOp, 0)}
	}

	// Branch target when the word's high byte is non-zero: that settles
	// byte > word as false and word > byte as true.
	var hiTarget string
	switch {
	case !byteOnLeft:
		hiTarget = success
	case cg.useLongJump:
		hiTarget = fail
	default:
		hiTarget = cg.skipLabel
	}

	out := []string{cg.loadOperand(wordOp, 1), "\tbne " + hiTarget}
	out = append(out, lowCompare...)

	if !cg.useLongJump {
		out = append(out, "\tbeq "+cg.skipLabel, "\tbcc "+cg.skipLabel)
		if !byteOnLeft {
			out = append(out, success)
		}
		return out, nil
	}
	return append(out,
		"\tbeq "+fail,
		"\tbcs "+success,
		fail,
		"\tjmp "+cg.skipLabel,
		success,
	), nil
}
// genLess emits "param1 < param2" by generating the equivalent
// "param2 > param1".
//
// The operands and their cached kinds are swapped only for the duration of
// the call and restored afterwards. cg is therefore left unchanged, and a
// repeated generate() still compiles the original comparison instead of the
// reversed one.
func (cg *comparisonGenerator) genLess() ([]string, error) {
	p1, p2, k1, k2 := cg.param1, cg.param2, cg.kind1, cg.kind2
	cg.param1, cg.param2 = p2, p1
	cg.kind1, cg.kind2 = k2, k1
	defer func() {
		cg.param1, cg.param2 = p1, p2
		cg.kind1, cg.kind2 = k1, k2
	}()
	return cg.genGreater()
}
// genGreaterEqual emits code for the unsigned comparison "param1 >= param2".
// It picks the byte, word or mixed-width sequence from the operands'
// effective kinds.
func (cg *comparisonGenerator) genGreaterEqual() ([]string, error) {
	switch {
	case cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte:
		return cg.genByteGreaterEqual()
	case cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord:
		return cg.genWordGreaterEqual()
	default:
		return cg.genMixedGreaterEqual()
	}
}
// genByteGreaterEqual emits "param1 >= param2" (unsigned) for two byte
// operands. After "LDA param1 / CMP param2", C is clear exactly when
// param1 < param2, i.e. when the comparison is false.
func (cg *comparisonGenerator) genByteGreaterEqual() ([]string, error) {
	out := []string{cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0)}
	if cg.useLongJump {
		pass := cg.tempLabel()
		return append(out, "\tbcs "+pass, "\tjmp "+cg.skipLabel, pass), nil
	}
	return append(out, "\tbcc "+cg.skipLabel), nil
}
// genWordGreaterEqual emits "param1 >= param2" (unsigned) for two word
// operands. The high bytes decide unless they are equal, in which case the
// low bytes decide.
func (cg *comparisonGenerator) genWordGreaterEqual() ([]string, error) {
	done := cg.tempLabel()
	hiLoad, hiCmp := cg.loadOperand(cg.param1, 1), cg.cmpOperand(cg.param2, 1)
	loLoad, loCmp := cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0)

	if cg.useLongJump {
		bail := cg.tempLabel()
		return []string{
			hiLoad, hiCmp,
			"\tbcc " + bail, // high byte smaller: false
			"\tbne " + done, // high byte larger: true
			loLoad, loCmp,
			"\tbcs " + done, // low byte >=: true
			bail,
			"\tjmp " + cg.skipLabel,
			done,
		}, nil
	}
	return []string{
		hiLoad, hiCmp,
		"\tbcc " + cg.skipLabel,
		"\tbne " + done,
		loLoad, loCmp,
		"\tbcc " + cg.skipLabel,
		done,
	}, nil
}
// genMixedGreaterEqual emits "param1 >= param2" (unsigned) when one operand is
// byte-sized and the other word-sized. The success label is always allocated
// first; only the long byte >= word sequence also allocates a fail label.
func (cg *comparisonGenerator) genMixedGreaterEqual() ([]string, error) {
	byteOp, wordOp := cg.extractByteWord()
	success := cg.tempLabel()

	if cg.kind1 != compiler.KindByte {
		// word >= byte: a non-zero high byte settles it as true.
		out := []string{
			cg.loadOperand(wordOp, 1),
			"\tbne " + success,
			cg.loadOperand(wordOp, 0),
			cg.cmpOperand(byteOp, 0),
		}
		if cg.useLongJump {
			return append(out, "\tbcs "+success, "\tjmp "+cg.skipLabel, success), nil
		}
		return append(out, "\tbcc "+cg.skipLabel, success), nil
	}

	// byte >= word: a non-zero high byte settles it as false.
	if !cg.useLongJump {
		return []string{
			cg.loadOperand(wordOp, 1),
			"\tbne " + cg.skipLabel,
			cg.loadOperand(byteOp, 0),
			cg.cmpOperand(wordOp, 0),
			"\tbcc " + cg.skipLabel,
		}, nil
	}
	fail := cg.tempLabel()
	return []string{
		cg.loadOperand(wordOp, 1),
		"\tbne " + fail,
		cg.loadOperand(byteOp, 0),
		cg.cmpOperand(wordOp, 0),
		"\tbcs " + success,
		fail,
		"\tjmp " + cg.skipLabel,
		success,
	}, nil
}
// genLessEqual emits "param1 <= param2" by generating the equivalent
// "param2 >= param1".
//
// The operands and their cached kinds are swapped only for the duration of
// the call and restored afterwards. cg is therefore left unchanged, and a
// repeated generate() still compiles the original comparison instead of the
// reversed one.
func (cg *comparisonGenerator) genLessEqual() ([]string, error) {
	p1, p2, k1, k2 := cg.param1, cg.param2, cg.kind1, cg.kind2
	cg.param1, cg.param2 = p2, p1
	cg.kind1, cg.kind2 = k2, k1
	defer func() {
		cg.param1, cg.param2 = p1, p2
		cg.kind1, cg.kind2 = k1, k2
	}()
	return cg.genGreaterEqual()
}
// inferKindFromValue returns the narrowest kind that can hold a literal:
// KindByte for 0..255 and KindWord for anything larger.
func inferKindFromValue(val uint16) compiler.VarKind {
	if val > 0xFF {
		return compiler.KindWord
	}
	return compiler.KindByte
}
// parseOperator maps a source-level operator spelling to a comparisonOp.
// Both "=" and "==" mean equality, and both "<>" and "!=" mean inequality.
// Any other spelling returns an error that quotes the input.
func parseOperator(op string) (comparisonOp, error) {
	var parsed comparisonOp
	switch op {
	case "=", "==":
		parsed = opEqual
	case "<>", "!=":
		parsed = opNotEqual
	case ">":
		parsed = opGreater
	case "<":
		parsed = opLess
	case ">=":
		parsed = opGreaterEqual
	case "<=":
		parsed = opLessEqual
	default:
		return 0, fmt.Errorf("unsupported operator %q", op)
	}
	return parsed, nil
}