package commands import ( "fmt" "c65gm/internal/compiler" ) // comparisonOp represents comparison operators type comparisonOp int const ( opEqual comparisonOp = iota opNotEqual opGreater opLess opGreaterEqual opLessEqual ) // operandInfo describes an operand (variable or literal) type operandInfo struct { varName string varKind compiler.VarKind value uint16 isVar bool } // comparisonGenerator generates comparison code for control flow type comparisonGenerator struct { operator comparisonOp param1 *operandInfo param2 *operandInfo useLongJump bool compFlowStack *compiler.LabelStack generalStack *compiler.LabelStack skipLabel string kind1 compiler.VarKind kind2 compiler.VarKind } // newComparisonGenerator creates a new comparison generator func newComparisonGenerator( operator comparisonOp, param1 *operandInfo, param2 *operandInfo, useLongJump bool, compFlowStack *compiler.LabelStack, generalStack *compiler.LabelStack, ) (*comparisonGenerator, error) { if param1 == nil { return nil, fmt.Errorf("param1 cannot be nil") } // Handle single operand: convert to != 0 if param2 == nil { operator = opNotEqual param2 = &operandInfo{value: 0, isVar: false} } return &comparisonGenerator{ operator: operator, param1: param1, param2: param2, useLongJump: useLongJump, compFlowStack: compFlowStack, generalStack: generalStack, }, nil } // generate produces assembly that jumps to skip label on FALSE func (cg *comparisonGenerator) generate() ([]string, error) { var err error cg.skipLabel, err = cg.compFlowStack.Peek() if err != nil { return nil, fmt.Errorf("comparison: %w", err) } cg.kind1 = cg.getKind(cg.param1) cg.kind2 = cg.getKind(cg.param2) // Constant folding if !cg.param1.isVar && !cg.param2.isVar { return cg.constantFold() } // Dispatch to operator switch cg.operator { case opEqual: return cg.genEqual() case opNotEqual: return cg.genNotEqual() case opGreater: return cg.genGreater() case opLess: return cg.genLess() case opGreaterEqual: return cg.genGreaterEqual() case opLessEqual: return cg.genLessEqual() default: return nil, fmt.Errorf("unsupported operator %d", cg.operator) } } // getKind returns the effective kind for an operand func (cg *comparisonGenerator) getKind(op *operandInfo) compiler.VarKind { if op.isVar { return op.varKind } return inferKindFromValue(op.value) } // constantFold evaluates comparisons between two literals func (cg *comparisonGenerator) constantFold() ([]string, error) { v1, v2 := cg.param1.value, cg.param2.value var result bool switch cg.operator { case opEqual: result = v1 == v2 case opNotEqual: result = v1 != v2 case opGreater: result = v1 > v2 case opLess: result = v1 < v2 case opGreaterEqual: result = v1 >= v2 case opLessEqual: result = v1 <= v2 default: return nil, fmt.Errorf("operator not supported for constant folding") } // If false, skip block if !result { return []string{fmt.Sprintf("\tjmp %s", cg.skipLabel)}, nil } return []string{}, nil } // loadOperand generates LDA instruction for operand func (cg *comparisonGenerator) loadOperand(op *operandInfo, offset int) string { if op.isVar { if offset > 0 { return fmt.Sprintf("\tlda %s+%d", op.varName, offset) } return fmt.Sprintf("\tlda %s", op.varName) } val := op.value if offset == 1 { val >>= 8 } return fmt.Sprintf("\tlda #$%02x", uint8(val)) } // cmpOperand generates CMP instruction for operand func (cg *comparisonGenerator) cmpOperand(op *operandInfo, offset int) string { if op.isVar { if offset > 0 { return fmt.Sprintf("\tcmp %s+%d", op.varName, offset) } return fmt.Sprintf("\tcmp %s", op.varName) } val := op.value if offset == 1 { val >>= 8 } return fmt.Sprintf("\tcmp #$%02x", uint8(val)) } // tempLabel creates and returns a temporary label func (cg *comparisonGenerator) tempLabel() string { label := cg.generalStack.Push() _, _ = cg.generalStack.Pop() return label } // extractByteWord separates byte and word operands func (cg *comparisonGenerator) extractByteWord() (byte, word *operandInfo) { if cg.kind1 == compiler.KindByte { return cg.param1, cg.param2 } return cg.param2, cg.param1 } // == operator func (cg *comparisonGenerator) genEqual() ([]string, error) { // byte == byte if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte { return cg.genByteEqual() } // word == word if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord { return cg.genWordEqual() } // mixed return cg.genMixedEqual() } func (cg *comparisonGenerator) genByteEqual() ([]string, error) { if !cg.useLongJump { return []string{ cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbne %s", cg.skipLabel), }, nil } success := cg.tempLabel() return []string{ cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbeq %s", success), fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } func (cg *comparisonGenerator) genWordEqual() ([]string, error) { if !cg.useLongJump { return []string{ cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbne %s", cg.skipLabel), cg.loadOperand(cg.param1, 1), cg.cmpOperand(cg.param2, 1), fmt.Sprintf("\tbne %s", cg.skipLabel), }, nil } success := cg.tempLabel() fail := cg.tempLabel() return []string{ cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbne %s", fail), cg.loadOperand(cg.param1, 1), cg.cmpOperand(cg.param2, 1), fmt.Sprintf("\tbeq %s", success), fail, fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } func (cg *comparisonGenerator) genMixedEqual() ([]string, error) { byteOp, wordOp := cg.extractByteWord() if !cg.useLongJump { return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", cg.skipLabel), cg.loadOperand(byteOp, 0), cg.cmpOperand(wordOp, 0), fmt.Sprintf("\tbne %s", cg.skipLabel), }, nil } success := cg.tempLabel() fail := cg.tempLabel() return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", fail), cg.loadOperand(byteOp, 0), cg.cmpOperand(wordOp, 0), fmt.Sprintf("\tbeq %s", success), fail, fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } // != operator func (cg *comparisonGenerator) genNotEqual() ([]string, error) { if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte { return cg.genByteNotEqual() } if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord { return cg.genWordNotEqual() } return cg.genMixedNotEqual() } func (cg *comparisonGenerator) genByteNotEqual() ([]string, error) { if !cg.useLongJump { return []string{ cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbeq %s", cg.skipLabel), }, nil } success := cg.tempLabel() return []string{ cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbne %s", success), fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } func (cg *comparisonGenerator) genWordNotEqual() ([]string, error) { success := cg.tempLabel() if !cg.useLongJump { return []string{ cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbne %s", success), cg.loadOperand(cg.param1, 1), cg.cmpOperand(cg.param2, 1), fmt.Sprintf("\tbne %s", success), fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } return []string{ cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbne %s", success), cg.loadOperand(cg.param1, 1), cg.cmpOperand(cg.param2, 1), fmt.Sprintf("\tbne %s", success), fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } func (cg *comparisonGenerator) genMixedNotEqual() ([]string, error) { byteOp, wordOp := cg.extractByteWord() success := cg.tempLabel() if !cg.useLongJump { return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", success), cg.loadOperand(byteOp, 0), cg.cmpOperand(wordOp, 0), fmt.Sprintf("\tbeq %s", cg.skipLabel), success, }, nil } return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", success), cg.loadOperand(byteOp, 0), cg.cmpOperand(wordOp, 0), fmt.Sprintf("\tbne %s", success), fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } // > operator (unsigned) func (cg *comparisonGenerator) genGreater() ([]string, error) { if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte { return cg.genByteGreater() } if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord { return cg.genWordGreater() } return cg.genMixedGreater() } func (cg *comparisonGenerator) genByteGreater() ([]string, error) { // p1 > p2: skip if p1 <= p2 // CMP sets C=1 if A >= operand if !cg.useLongJump { return []string{ cg.loadOperand(cg.param2, 0), cg.cmpOperand(cg.param1, 0), fmt.Sprintf("\tbcs %s", cg.skipLabel), // skip if p2 >= p1 (i.e., p1 <= p2) }, nil } success := cg.tempLabel() fail := cg.tempLabel() return []string{ cg.loadOperand(cg.param2, 0), cg.cmpOperand(cg.param1, 0), fmt.Sprintf("\tbcc %s", success), // if p2 < p1, success (p1 > p2) fail, fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } func (cg *comparisonGenerator) genWordGreater() ([]string, error) { // Compare high bytes first success := cg.tempLabel() if !cg.useLongJump { return []string{ cg.loadOperand(cg.param1, 1), cg.cmpOperand(cg.param2, 1), fmt.Sprintf("\tbcc %s", cg.skipLabel), // p1_hi < p2_hi fmt.Sprintf("\tbne %s", success), // p1_hi > p2_hi // High bytes equal, check low - swap operands for single branch cg.loadOperand(cg.param2, 0), cg.cmpOperand(cg.param1, 0), fmt.Sprintf("\tbcs %s", cg.skipLabel), // skip if p2 >= p1 (i.e., p1 <= p2) success, }, nil } fail := cg.tempLabel() return []string{ cg.loadOperand(cg.param1, 1), cg.cmpOperand(cg.param2, 1), fmt.Sprintf("\tbcc %s", fail), fmt.Sprintf("\tbne %s", success), // High bytes equal, check low - swap operands for single branch cg.loadOperand(cg.param2, 0), cg.cmpOperand(cg.param1, 0), fmt.Sprintf("\tbcc %s", success), // if p2 < p1, success (p1 > p2) fail, fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } func (cg *comparisonGenerator) genMixedGreater() ([]string, error) { byteOp, wordOp := cg.extractByteWord() // If byte is param1 and word is param2: byte > word only if word_hi=0 and byte > word_lo // If word is param1 and byte is param2: word > byte if word_hi>0 OR word_lo > byte if cg.kind1 == compiler.KindByte { // byte > word success := cg.tempLabel() if !cg.useLongJump { return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", cg.skipLabel), // word too large cg.loadOperand(byteOp, 0), cg.cmpOperand(wordOp, 0), fmt.Sprintf("\tbeq %s", cg.skipLabel), fmt.Sprintf("\tbcc %s", cg.skipLabel), }, nil } fail := cg.tempLabel() return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", fail), cg.loadOperand(byteOp, 0), cg.cmpOperand(wordOp, 0), fmt.Sprintf("\tbeq %s", fail), fmt.Sprintf("\tbcs %s", success), fail, fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } // word > byte success := cg.tempLabel() if !cg.useLongJump { return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", success), // word_hi > 0 means always greater cg.loadOperand(wordOp, 0), cg.cmpOperand(byteOp, 0), fmt.Sprintf("\tbeq %s", cg.skipLabel), fmt.Sprintf("\tbcc %s", cg.skipLabel), success, }, nil } fail := cg.tempLabel() return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", success), cg.loadOperand(wordOp, 0), cg.cmpOperand(byteOp, 0), fmt.Sprintf("\tbeq %s", fail), fmt.Sprintf("\tbcs %s", success), fail, fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } // < operator func (cg *comparisonGenerator) genLess() ([]string, error) { // p1 < p2 is equivalent to p2 > p1 cg.param1, cg.param2 = cg.param2, cg.param1 cg.kind1, cg.kind2 = cg.kind2, cg.kind1 return cg.genGreater() } // >= operator func (cg *comparisonGenerator) genGreaterEqual() ([]string, error) { if cg.kind1 == compiler.KindByte && cg.kind2 == compiler.KindByte { return cg.genByteGreaterEqual() } if cg.kind1 == compiler.KindWord && cg.kind2 == compiler.KindWord { return cg.genWordGreaterEqual() } return cg.genMixedGreaterEqual() } func (cg *comparisonGenerator) genByteGreaterEqual() ([]string, error) { // p1 >= p2: skip if p1 < p2 if !cg.useLongJump { return []string{ cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbcc %s", cg.skipLabel), }, nil } success := cg.tempLabel() return []string{ cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbcs %s", success), fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } func (cg *comparisonGenerator) genWordGreaterEqual() ([]string, error) { success := cg.tempLabel() if !cg.useLongJump { return []string{ cg.loadOperand(cg.param1, 1), cg.cmpOperand(cg.param2, 1), fmt.Sprintf("\tbcc %s", cg.skipLabel), fmt.Sprintf("\tbne %s", success), cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbcc %s", cg.skipLabel), success, }, nil } fail := cg.tempLabel() return []string{ cg.loadOperand(cg.param1, 1), cg.cmpOperand(cg.param2, 1), fmt.Sprintf("\tbcc %s", fail), fmt.Sprintf("\tbne %s", success), cg.loadOperand(cg.param1, 0), cg.cmpOperand(cg.param2, 0), fmt.Sprintf("\tbcs %s", success), fail, fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } func (cg *comparisonGenerator) genMixedGreaterEqual() ([]string, error) { byteOp, wordOp := cg.extractByteWord() if cg.kind1 == compiler.KindByte { // byte >= word success := cg.tempLabel() if !cg.useLongJump { return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", cg.skipLabel), cg.loadOperand(byteOp, 0), cg.cmpOperand(wordOp, 0), fmt.Sprintf("\tbcc %s", cg.skipLabel), }, nil } fail := cg.tempLabel() return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", fail), cg.loadOperand(byteOp, 0), cg.cmpOperand(wordOp, 0), fmt.Sprintf("\tbcs %s", success), fail, fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } // word >= byte success := cg.tempLabel() if !cg.useLongJump { return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", success), cg.loadOperand(wordOp, 0), cg.cmpOperand(byteOp, 0), fmt.Sprintf("\tbcc %s", cg.skipLabel), success, }, nil } return []string{ cg.loadOperand(wordOp, 1), fmt.Sprintf("\tbne %s", success), cg.loadOperand(wordOp, 0), cg.cmpOperand(byteOp, 0), fmt.Sprintf("\tbcs %s", success), fmt.Sprintf("\tjmp %s", cg.skipLabel), success, }, nil } // <= operator func (cg *comparisonGenerator) genLessEqual() ([]string, error) { // p1 <= p2 is equivalent to p2 >= p1 cg.param1, cg.param2 = cg.param2, cg.param1 cg.kind1, cg.kind2 = cg.kind2, cg.kind1 return cg.genGreaterEqual() } func inferKindFromValue(val uint16) compiler.VarKind { if val <= 255 { return compiler.KindByte } return compiler.KindWord } // parseOperator converts string operator to comparisonOp func parseOperator(op string) (comparisonOp, error) { switch op { case "=", "==": return opEqual, nil case "<>", "!=": return opNotEqual, nil case ">": return opGreater, nil case "<": return opLess, nil case ">=": return opGreaterEqual, nil case "<=": return opLessEqual, nil default: return 0, fmt.Errorf("unsupported operator %q", op) } }