package commands

import (
	"fmt"
	"os"
	"strings"

	"c65gm/internal/compiler"
	"c65gm/internal/preproc"
	"c65gm/internal/utils"
)

// ShiftRCommand handles logical shift right operations.
//
// Syntax:
//
//	SHIFTR <source> BY <amount> GIVING <dest>   # old syntax with BY/GIVING
//	SHIFTR <source> >> <amount> -> <dest>       # old syntax with >>/->
//	<dest> = <source> >> <amount>               # new syntax
//
// Interpret fills the fields below from one source line; Generate then turns
// them into assembly text.
type ShiftRCommand struct {
	// Source operand: either a variable (sourceIsVar) or a literal value.
	sourceVarName string
	sourceVarKind compiler.VarKind
	sourceValue   uint16
	sourceIsVar   bool

	// Shift amount: either a variable (amountIsVar) or a literal value.
	amountVarName string
	amountVarKind compiler.VarKind
	amountValue   uint16
	amountIsVar   bool

	// Destination variable (always a variable, never a constant).
	destVarName string
	destVarKind compiler.VarKind

	line preproc.Line // Store line info for warnings
}

// WillHandle reports whether line looks like a SHIFTR statement in either the
// old (keyword-first, 6 parameters) or the new (assignment, 5 parameters)
// syntax. It only inspects the token shape; operands are validated later by
// Interpret.
func (c *ShiftRCommand) WillHandle(line preproc.Line) bool {
	params, err := utils.ParseParams(line.Text)
	if err != nil || len(params) == 0 {
		return false
	}

	// Old syntax: SHIFTR ... (must have exactly 6 params)
	if strings.ToUpper(params[0]) == "SHIFTR" && len(params) == 6 {
		return true
	}

	// New syntax: <dest> = <source> >> <amount>
	return len(params) == 5 && params[1] == "=" && params[3] == ">>"
}

// Interpret parses line into the command's fields.
//
// It returns an error when the parameter list is empty or has the wrong
// length, when a separator token is not the expected one, when the
// destination is unknown or a constant, when the source or amount operand
// cannot be parsed, or when the amount is a WORD variable or a literal
// above 255.
func (c *ShiftRCommand) Interpret(line preproc.Line, ctx *compiler.CompilerContext) error {
	// Clear state left over from a previous line.
	c.sourceVarName = ""
	c.sourceIsVar = false
	c.sourceValue = 0
	c.amountVarName = ""
	c.amountIsVar = false
	c.amountValue = 0
	c.destVarName = ""
	c.line = line // Store line for warnings

	params, err := utils.ParseParams(line.Text)
	if err != nil {
		return err
	}
	paramCount := len(params)
	if paramCount == 0 {
		// Guard the params[0] access below; WillHandle normally filters this
		// out, but Interpret must not panic when called directly.
		return fmt.Errorf("SHIFTR: missing parameters")
	}

	scope := ctx.CurrentScope()

	// Create constant lookup function
	constLookup := ctx.SymbolTable.ConstantLookupFunc(scope)

	// Validate the token shape and work out where each operand sits.
	var destIdx, sourceIdx, amountIdx int
	if strings.ToUpper(params[0]) == "SHIFTR" {
		// Old syntax: SHIFTR <source> BY/>> <amount> GIVING/-> <dest>
		if paramCount != 6 {
			return fmt.Errorf("SHIFTR: wrong number of parameters (%d), expected 6", paramCount)
		}
		if sep := strings.ToUpper(params[2]); sep != "BY" && sep != ">>" {
			return fmt.Errorf("SHIFTR: parameter #3 must be 'BY' or '>>', got %q", params[2])
		}
		if sep := strings.ToUpper(params[4]); sep != "GIVING" && sep != "->" {
			return fmt.Errorf("SHIFTR: parameter #5 must be 'GIVING' or '->', got %q", params[4])
		}
		destIdx, sourceIdx, amountIdx = 5, 1, 3
	} else {
		// New syntax: <dest> = <source> >> <amount>
		if paramCount != 5 {
			return fmt.Errorf("SHIFTR: wrong number of parameters (%d), expected 5", paramCount)
		}
		if params[1] != "=" {
			return fmt.Errorf("SHIFTR: expected '=' at position 2, got %q", params[1])
		}
		if params[3] != ">>" {
			return fmt.Errorf("SHIFTR: expected '>>' at position 4, got %q", params[3])
		}
		destIdx, sourceIdx, amountIdx = 0, 2, 4
	}

	// Parse destination
	destName := params[destIdx]
	destSym := ctx.SymbolTable.Lookup(destName, scope)
	if destSym == nil {
		return fmt.Errorf("SHIFTR: unknown variable %q", destName)
	}
	if destSym.IsConst() {
		return fmt.Errorf("SHIFTR: cannot assign to constant %q", destName)
	}
	c.destVarName = destSym.FullName()
	c.destVarKind = destSym.GetVarKind()

	// Parse source
	c.sourceVarName, c.sourceVarKind, c.sourceValue, c.sourceIsVar, err = compiler.ParseOperandParam(
		params[sourceIdx], ctx.SymbolTable, scope, constLookup)
	if err != nil {
		return fmt.Errorf("SHIFTR: source: %w", err)
	}

	// Parse amount
	c.amountVarName, c.amountVarKind, c.amountValue, c.amountIsVar, err = compiler.ParseOperandParam(
		params[amountIdx], ctx.SymbolTable, scope, constLookup)
	if err != nil {
		return fmt.Errorf("SHIFTR: amount: %w", err)
	}

	// Validate amount: the variable-amount loop loads it with a single LDX,
	// so only a byte-sized amount is accepted.
	if c.amountIsVar {
		if c.amountVarKind == compiler.KindWord {
			return fmt.Errorf("SHIFTR: amount must be BYTE variable, got WORD %q", c.amountVarName)
		}
	} else if c.amountValue > 255 {
		return fmt.Errorf("SHIFTR: amount constant %d out of BYTE range (0-255)", c.amountValue)
	}

	return nil
}

// Generate emits the assembly for the statement parsed by Interpret.
//
// A constant amount that is at least as wide as the destination shifts every
// bit out, so the destination is simply zeroed (and a warning is printed to
// stderr). Otherwise the source is copied into the destination, unless they
// are the same variable, and the destination is shifted in place.
func (c *ShiftRCommand) Generate(ctx *compiler.CompilerContext) ([]string, error) {
	var asm []string

	// Constant amount >= bit width: result is zero, skip copy and shift.
	bitWidth := c.destBitWidth()
	if !c.amountIsVar && c.amountValue >= uint16(bitWidth) {
		// The warning has to be raised here: this early return means
		// generateShift is never reached for this case.
		c.warnOvershift(c.amountValue, bitWidth)
		return c.zeroDest(), nil
	}

	// Step 1: Copy source to destination unless it already lives there.
	if !c.sourceIsVar || c.sourceVarName != c.destVarName {
		asm = append(asm, c.generateCopy()...)
	}

	// Step 2: Apply shift
	shiftAsm, err := c.generateShift(ctx)
	if err != nil {
		return nil, err
	}
	asm = append(asm, shiftAsm...)

	return asm, nil
}

// destBitWidth returns the destination width in bits: 16 for a WORD
// destination, 8 otherwise.
func (c *ShiftRCommand) destBitWidth() int {
	if c.destVarKind == compiler.KindWord {
		return 16
	}
	return 8
}

// warnOvershift prints the "all bits shifted out" warning for the stored line.
func (c *ShiftRCommand) warnOvershift(amount uint16, bitWidth int) {
	// Best-effort diagnostic: a failed write to stderr must not fail the build.
	_, _ = fmt.Fprintf(os.Stderr,
		"%s:%d: warning: shift amount %d >= %d bits, value will be zero\n",
		c.line.Filename, c.line.LineNo, amount, bitWidth)
}

// zeroDest returns assembly that stores zero into the destination: one store
// for a BYTE destination, low and high byte stores otherwise.
func (c *ShiftRCommand) zeroDest() []string {
	asm := []string{
		"\tlda #0",
		fmt.Sprintf("\tsta %s", c.destVarName),
	}
	if c.destVarKind != compiler.KindByte {
		asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName))
	}
	return asm
}

// shiftStep returns assembly that shifts the destination right by one bit.
// For a non-BYTE destination the high byte is shifted first and its carry is
// rotated into the low byte.
func (c *ShiftRCommand) shiftStep() []string {
	if c.destVarKind == compiler.KindByte {
		return []string{fmt.Sprintf("\tlsr %s", c.destVarName)}
	}
	return []string{
		fmt.Sprintf("\tlsr %s+1", c.destVarName),
		fmt.Sprintf("\tror %s", c.destVarName),
	}
}

// generateCopy generates assembly to copy source to destination.
//
// NOTE(review): the copy is sized by the destination, so a WORD source or a
// literal above $FF is reduced to its low byte before a BYTE destination is
// shifted — confirm this is the intended language semantics.
func (c *ShiftRCommand) generateCopy() []string {
	var asm []string

	// Literal source: load immediates.
	if !c.sourceIsVar {
		lo := uint8(c.sourceValue & 0xFF)
		hi := uint8((c.sourceValue >> 8) & 0xFF)

		asm = append(asm, fmt.Sprintf("\tlda #$%02x", lo))
		asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName))
		if c.destVarKind != compiler.KindByte {
			// When both bytes are equal A already holds the value, so the
			// second load is skipped.
			if lo != hi {
				asm = append(asm, fmt.Sprintf("\tlda #$%02x", hi))
			}
			asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName))
		}
		return asm
	}

	// Variable source: the low byte is always copied.
	asm = append(asm, fmt.Sprintf("\tlda %s", c.sourceVarName))
	asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName))

	switch {
	case c.destVarKind == compiler.KindByte:
		// Byte destination: nothing more to copy.
	case c.sourceVarKind == compiler.KindByte:
		// Byte -> Word (zero-extend)
		asm = append(asm, "\tlda #0")
		asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName))
	default:
		// Word -> Word
		asm = append(asm, fmt.Sprintf("\tlda %s+1", c.sourceVarName))
		asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName))
	}

	return asm
}

// generateShift generates assembly to shift destination right by amount.
//
// A constant amount is unrolled into that many single-bit shifts (none for
// zero). A variable amount becomes a loop counted down in X, skipped entirely
// when the amount is zero. The returned error is always nil.
func (c *ShiftRCommand) generateShift(ctx *compiler.CompilerContext) ([]string, error) {
	var asm []string

	// Constant amount
	if !c.amountIsVar {
		amount := c.amountValue
		if amount == 0 {
			return asm, nil // No shift needed
		}

		// Defensive path: Generate already handles this case before calling
		// here, but keep direct calls from unrolling a pointless long shift.
		if bitWidth := c.destBitWidth(); amount >= uint16(bitWidth) {
			c.warnOvershift(amount, bitWidth)
			return c.zeroDest(), nil
		}

		// Unroll shift loop
		step := c.shiftStep()
		for i := uint16(0); i < amount; i++ {
			asm = append(asm, step...)
		}
		return asm, nil
	}

	// Variable amount
	// Generate labels: each Push is immediately popped, so only the returned
	// label names are kept.
	loopLabel := ctx.GeneralStack.Push()
	ctx.GeneralStack.Pop()
	doneLabel := ctx.GeneralStack.Push()
	ctx.GeneralStack.Pop()

	// Load amount into X
	asm = append(asm, fmt.Sprintf("\tldx %s", c.amountVarName))

	// Check for zero amount (LDX sets the zero flag)
	asm = append(asm, fmt.Sprintf("\tbeq %s", doneLabel))

	// Shift loop
	asm = append(asm, loopLabel)
	asm = append(asm, c.shiftStep()...)
	asm = append(asm, "\tdex")
	asm = append(asm, fmt.Sprintf("\tbne %s", loopLabel))

	asm = append(asm, doneLabel)

	return asm, nil
}