From 822326f9932712dae0f1c1197a917240b8bb6e97 Mon Sep 17 00:00:00 2001 From: Mattias Hansson Date: Wed, 15 Apr 2026 03:03:35 +0200 Subject: [PATCH] Optimizations of shift operators --- internal/commands/shiftl.go | 354 +++++++++++++++-------- internal/commands/shiftl_test.go | 125 +++++++++ internal/commands/shiftr.go | 467 +++++++++++++++++++++++-------- internal/commands/shiftr_test.go | 130 +++++++++ 4 files changed, 843 insertions(+), 233 deletions(-) diff --git a/internal/commands/shiftl.go b/internal/commands/shiftl.go index b1b5004..85c6a21 100644 --- a/internal/commands/shiftl.go +++ b/internal/commands/shiftl.go @@ -175,84 +175,177 @@ func (c *ShiftLCommand) Interpret(line preproc.Line, ctx *compiler.CompilerConte } func (c *ShiftLCommand) Generate(ctx *compiler.CompilerContext) ([]string, error) { - var asm []string - - // Check if shift amount >= bit width (result will be zero) - bitWidth := 8 - if c.destVarKind == compiler.KindWord { - bitWidth = 16 + // Case 1: BYTE destination + if c.destVarKind == compiler.KindByte { + return c.generateByteShift(ctx) } - - amountZero := false - if !c.amountIsVar { - // Constant amount - if c.amountValue >= uint16(bitWidth) { - amountZero = true - } - } - - if amountZero { - // Result is zero, just store zero - if c.destVarKind == compiler.KindByte { - asm = append(asm, "\tlda #0") - asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) - } else { - asm = append(asm, "\tlda #0") - asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) - asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) - } - return asm, nil - } - - // Step 1: Copy source to destination if needed - if c.sourceIsVar && c.sourceVarName == c.destVarName { - // Same variable, no copy needed - } else { - copyAsm := c.generateCopy() - asm = append(asm, copyAsm...) - } - - // Step 2: Apply shift - shiftAsm, err := c.generateShift(ctx) - if err != nil { - return nil, err - } - asm = append(asm, shiftAsm...) - - return asm, nil + + // Case 2: WORD destination + return c.generateWordShift(ctx) } -// generateCopy generates assembly to copy source to destination -func (c *ShiftLCommand) generateCopy() []string { +// generateByteShift handles BYTE << amount +func (c *ShiftLCommand) generateByteShift(ctx *compiler.CompilerContext) ([]string, error) { var asm []string - - // If source is literal, just load it - if !c.sourceIsVar { - lo := uint8(c.sourceValue & 0xFF) - hi := uint8((c.sourceValue >> 8) & 0xFF) - - if c.destVarKind == compiler.KindByte { - asm = append(asm, fmt.Sprintf("\tlda #$%02x", lo)) - asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) - } else { - asm = append(asm, fmt.Sprintf("\tlda #$%02x", lo)) - asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) - // Optimization: don't reload if lo == hi - if lo != hi { - asm = append(asm, fmt.Sprintf("\tlda #$%02x", hi)) + + // Constant shift amount + if !c.amountIsVar { + amount := c.amountValue + + // Shift >= 8 bits -> zero + if amount >= 8 { + if amount > 0 { + _, _ = fmt.Fprintf(os.Stderr, "%s:%d: warning: shift amount %d >= 8 bits, value will be zero\n", + c.line.Filename, c.line.LineNo, amount) } - asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) + + // Store zero + if c.sourceIsVar && c.sourceVarName == c.destVarName { + // Same variable: just zero it + asm = append(asm, "\tlda #0") + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + } else { + // Different: load zero and store + asm = append(asm, "\tlda #0") + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + } + return asm, nil } - return asm + + // Shift 0 -> copy source + if amount == 0 { + return c.generateByteCopy(), nil + } + + // Constant shift 1-7 + return c.generateByteShiftConst(amount), nil } + + // Variable shift amount + return c.generateByteShiftVar(ctx) +} - // Source is variable - if c.destVarKind == compiler.KindByte { - // Destination is byte +// generateByteCopy copies source to destination (BYTE) +func (c *ShiftLCommand) generateByteCopy() []string { + var asm []string + + if c.sourceIsVar { + if c.sourceVarName == c.destVarName { + // Same variable, no copy needed + return asm + } + // Different variable asm = append(asm, fmt.Sprintf("\tlda %s", c.sourceVarName)) asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) } else { - // Destination is word + // Literal + val := uint8(c.sourceValue & 0xFF) + asm = append(asm, fmt.Sprintf("\tlda #$%02x", val)) + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + } + + return asm +} + +// generateByteShiftConst generates BYTE << constant (1-7) +func (c *ShiftLCommand) generateByteShiftConst(amount uint16) []string { + var asm []string + + // Copy source to destination if needed + copyAsm := c.generateByteCopy() + asm = append(asm, copyAsm...) + + // Apply shift + for i := uint16(0); i < amount; i++ { + asm = append(asm, fmt.Sprintf("\tasl %s", c.destVarName)) + } + + return asm +} + +// generateByteShiftVar generates BYTE << variable +func (c *ShiftLCommand) generateByteShiftVar(ctx *compiler.CompilerContext) ([]string, error) { + var asm []string + + // Copy source to destination if needed + copyAsm := c.generateByteCopy() + asm = append(asm, copyAsm...) + + // Generate labels + loopLabel := ctx.GeneralStack.Push() + ctx.GeneralStack.Pop() + doneLabel := ctx.GeneralStack.Push() + ctx.GeneralStack.Pop() + + // Variable shift loop + asm = append(asm, fmt.Sprintf("\tldx %s", c.amountVarName)) + asm = append(asm, fmt.Sprintf("\tbeq %s", doneLabel)) + asm = append(asm, loopLabel) + asm = append(asm, fmt.Sprintf("\tasl %s", c.destVarName)) + asm = append(asm, "\tdex") + asm = append(asm, fmt.Sprintf("\tbne %s", loopLabel)) + asm = append(asm, doneLabel) + + return asm, nil +} + +// generateWordShift handles WORD << amount +func (c *ShiftLCommand) generateWordShift(ctx *compiler.CompilerContext) ([]string, error) { + // Constant shift amount + if !c.amountIsVar { + amount := c.amountValue + + // Shift >= 16 bits -> zero + if amount >= 16 { + if amount > 0 { + _, _ = fmt.Fprintf(os.Stderr, "%s:%d: warning: shift amount %d >= 16 bits, value will be zero\n", + c.line.Filename, c.line.LineNo, amount) + } + + // Store zero + if c.sourceIsVar && c.sourceVarName == c.destVarName { + // Same variable: zero both bytes + asm := []string{ + "\tlda #0", + fmt.Sprintf("\tsta %s", c.destVarName), + fmt.Sprintf("\tsta %s+1", c.destVarName), + } + return asm, nil + } else { + // Different: load zero and store to both bytes + asm := []string{ + "\tlda #0", + fmt.Sprintf("\tsta %s", c.destVarName), + fmt.Sprintf("\tsta %s+1", c.destVarName), + } + return asm, nil + } + } + + // Shift 0 -> copy source + if amount == 0 { + return c.generateWordCopy(), nil + } + + // Constant shift 1-15 + return c.generateWordShiftConst(amount), nil + } + + // Variable shift amount + return c.generateWordShiftVar(ctx) +} + +// generateWordCopy copies source to destination (WORD) +func (c *ShiftLCommand) generateWordCopy() []string { + var asm []string + + if c.sourceIsVar && c.sourceVarName == c.destVarName { + // Same variable, no copy needed + return asm + } + + if c.sourceIsVar { + // Variable source if c.sourceVarKind == compiler.KindByte { // Byte -> Word (zero-extend) asm = append(asm, fmt.Sprintf("\tlda %s", c.sourceVarName)) @@ -266,83 +359,96 @@ func (c *ShiftLCommand) generateCopy() []string { asm = append(asm, fmt.Sprintf("\tlda %s+1", c.sourceVarName)) asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) } + } else { + // Literal source + lo := uint8(c.sourceValue & 0xFF) + hi := uint8((c.sourceValue >> 8) & 0xFF) + asm = append(asm, fmt.Sprintf("\tlda #$%02x", lo)) + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + if lo != hi { + asm = append(asm, fmt.Sprintf("\tlda #$%02x", hi)) + } + asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) } - + return asm } -// generateShift generates assembly to shift destination left by amount -func (c *ShiftLCommand) generateShift(ctx *compiler.CompilerContext) ([]string, error) { +// generateWordShiftConst generates WORD << constant (1-15) +func (c *ShiftLCommand) generateWordShiftConst(amount uint16) []string { var asm []string - - // Constant amount - if !c.amountIsVar { - amount := c.amountValue - if amount == 0 { - return asm, nil // No shift needed - } - - // Determine bit width - bitWidth := 8 - if c.destVarKind == compiler.KindWord { - bitWidth = 16 - } - - // Warn if shift amount >= bit width (but not for 0) - if amount >= uint16(bitWidth) { - _, _ = fmt.Fprintf(os.Stderr, "%s:%d: warning: shift amount %d >= %d bits, value will be zero\n", - c.line.Filename, c.line.LineNo, amount, bitWidth) - } - - if amount >= uint16(bitWidth) { - // Shift all bits out -> zero - if c.destVarKind == compiler.KindByte { - asm = append(asm, "\tlda #0") - asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + + // Special case: shift >= 8 + if amount >= 8 && amount < 16 { + remaining := amount - 8 + + // Get source low byte + if c.sourceIsVar { + if c.sourceVarName == c.destVarName { + // Same variable: read from destination (already has value) + asm = append(asm, fmt.Sprintf("\tlda %s", c.destVarName)) } else { - asm = append(asm, "\tlda #0") - asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) - asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) + // Different: read from source + asm = append(asm, fmt.Sprintf("\tlda %s", c.sourceVarName)) } - return asm, nil + } else { + // Literal + lo := uint8(c.sourceValue & 0xFF) + asm = append(asm, fmt.Sprintf("\tlda #$%02x", lo)) } - - // Unroll shift loop - for i := uint16(0); i < amount; i++ { - if c.destVarKind == compiler.KindByte { - asm = append(asm, fmt.Sprintf("\tasl %s", c.destVarName)) - } else { - asm = append(asm, fmt.Sprintf("\tasl %s", c.destVarName)) - asm = append(asm, fmt.Sprintf("\trol %s+1", c.destVarName)) - } + + // Store to destination high byte + asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) + + // Zero low byte + asm = append(asm, "\tlda #0") + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + + // Shift high byte remaining bits + for i := uint16(0); i < remaining; i++ { + asm = append(asm, fmt.Sprintf("\tasl %s+1", c.destVarName)) } - return asm, nil + + return asm } + + // Shift 1-7: normal copy and shift + copyAsm := c.generateWordCopy() + asm = append(asm, copyAsm...) + + // Apply shift + for i := uint16(0); i < amount; i++ { + asm = append(asm, fmt.Sprintf("\tasl %s", c.destVarName)) + asm = append(asm, fmt.Sprintf("\trol %s+1", c.destVarName)) + } + + return asm +} - // Variable amount +// generateWordShiftVar generates WORD << variable +func (c *ShiftLCommand) generateWordShiftVar(ctx *compiler.CompilerContext) ([]string, error) { + var asm []string + + // Copy source to destination if needed + copyAsm := c.generateWordCopy() + asm = append(asm, copyAsm...) + // Generate labels loopLabel := ctx.GeneralStack.Push() ctx.GeneralStack.Pop() doneLabel := ctx.GeneralStack.Push() ctx.GeneralStack.Pop() - - // Load amount into X + + // Variable shift loop asm = append(asm, fmt.Sprintf("\tldx %s", c.amountVarName)) - - // Check for zero amount asm = append(asm, fmt.Sprintf("\tbeq %s", doneLabel)) - - // Shift loop asm = append(asm, loopLabel) - if c.destVarKind == compiler.KindByte { - asm = append(asm, fmt.Sprintf("\tasl %s", c.destVarName)) - } else { - asm = append(asm, fmt.Sprintf("\tasl %s", c.destVarName)) - asm = append(asm, fmt.Sprintf("\trol %s+1", c.destVarName)) - } + asm = append(asm, fmt.Sprintf("\tasl %s", c.destVarName)) + asm = append(asm, fmt.Sprintf("\trol %s+1", c.destVarName)) asm = append(asm, "\tdex") asm = append(asm, fmt.Sprintf("\tbne %s", loopLabel)) - asm = append(asm, doneLabel) + return asm, nil -} \ No newline at end of file +} + diff --git a/internal/commands/shiftl_test.go b/internal/commands/shiftl_test.go index 492bbc3..8248069 100644 --- a/internal/commands/shiftl_test.go +++ b/internal/commands/shiftl_test.go @@ -518,6 +518,131 @@ func TestShiftLCommand_Generate(t *testing.T) { "_L2", }, }, + { + name: "WORD optimization - shift by 8", + setup: func(ctx *compiler.CompilerContext) *ShiftLCommand { + ctx.SymbolTable.AddVar("result", "", compiler.KindWord, 0, preproc.Line{Filename: "test.c65", LineNo: 1}) + return &ShiftLCommand{ + sourceIsVar: false, + sourceValue: 0x1234, + amountIsVar: false, + amountValue: 8, + destVarName: "result", + destVarKind: compiler.KindWord, + } + }, + wantLines: []string{ + "\tlda #$34", + "\tsta result+1", + "\tlda #0", + "\tsta result", + }, + }, + { + name: "WORD optimization - shift by 12 (8 + 4)", + setup: func(ctx *compiler.CompilerContext) *ShiftLCommand { + ctx.SymbolTable.AddVar("result", "", compiler.KindWord, 0, preproc.Line{Filename: "test.c65", LineNo: 1}) + return &ShiftLCommand{ + sourceIsVar: false, + sourceValue: 0x00AB, + amountIsVar: false, + amountValue: 12, + destVarName: "result", + destVarKind: compiler.KindWord, + } + }, + wantLines: []string{ + "\tlda #$ab", + "\tsta result+1", + "\tlda #0", + "\tsta result", + "\tasl result+1", + "\tasl result+1", + "\tasl result+1", + "\tasl result+1", + }, + }, + { + name: "BYTE to WORD conversion with shift", + setup: func(ctx *compiler.CompilerContext) *ShiftLCommand { + ctx.SymbolTable.AddVar("byteval", "", compiler.KindByte, 0x55, preproc.Line{Filename: "test.c65", LineNo: 1}) + ctx.SymbolTable.AddVar("result", "", compiler.KindWord, 0, preproc.Line{Filename: "test.c65", LineNo: 1}) + return &ShiftLCommand{ + sourceIsVar: true, + sourceVarName: "byteval", + sourceVarKind: compiler.KindByte, + amountIsVar: false, + amountValue: 4, + destVarName: "result", + destVarKind: compiler.KindWord, + } + }, + wantLines: []string{ + "\tlda byteval", + "\tsta result", + "\tlda #0", + "\tsta result+1", + "\tasl result", + "\trol result+1", + "\tasl result", + "\trol result+1", + "\tasl result", + "\trol result+1", + "\tasl result", + "\trol result+1", + }, + }, + { + name: "WORD to BYTE conversion with shift (truncation)", + setup: func(ctx *compiler.CompilerContext) *ShiftLCommand { + ctx.SymbolTable.AddVar("wordval", "", compiler.KindWord, 0x1234, preproc.Line{Filename: "test.c65", LineNo: 1}) + ctx.SymbolTable.AddVar("result", "", compiler.KindByte, 0, preproc.Line{Filename: "test.c65", LineNo: 1}) + return &ShiftLCommand{ + sourceIsVar: true, + sourceVarName: "wordval", + sourceVarKind: compiler.KindWord, + amountIsVar: false, + amountValue: 4, + destVarName: "result", + destVarKind: compiler.KindByte, + } + }, + wantLines: []string{ + "\tlda wordval", + "\tsta result", + "\tasl result", + "\tasl result", + "\tasl result", + "\tasl result", + }, + }, + { + name: "WORD same source and destination", + setup: func(ctx *compiler.CompilerContext) *ShiftLCommand { + ctx.SymbolTable.AddVar("value", "", compiler.KindWord, 0x1234, preproc.Line{Filename: "test.c65", LineNo: 1}) + ctx.SymbolTable.AddVar("shift", "", compiler.KindByte, 2, preproc.Line{Filename: "test.c65", LineNo: 1}) + return &ShiftLCommand{ + sourceIsVar: true, + sourceVarName: "value", + sourceVarKind: compiler.KindWord, + amountIsVar: true, + amountVarName: "shift", + amountVarKind: compiler.KindByte, + destVarName: "value", + destVarKind: compiler.KindWord, + } + }, + wantLines: []string{ + "\tldx shift", + "\tbeq _L2", + "_L1", + "\tasl value", + "\trol value+1", + "\tdex", + "\tbne _L1", + "_L2", + }, + }, } for _, tt := range tests { diff --git a/internal/commands/shiftr.go b/internal/commands/shiftr.go index ffa0e1b..f366fc9 100644 --- a/internal/commands/shiftr.go +++ b/internal/commands/shiftr.go @@ -175,83 +175,320 @@ func (c *ShiftRCommand) Interpret(line preproc.Line, ctx *compiler.CompilerConte } func (c *ShiftRCommand) Generate(ctx *compiler.CompilerContext) ([]string, error) { + // Case 1: BYTE destination + if c.destVarKind == compiler.KindByte { + return c.generateByteShift(ctx) + } + + // Case 2: WORD destination + return c.generateWordShift(ctx) +} + +// generateByteShift handles BYTE >> amount +func (c *ShiftRCommand) generateByteShift(ctx *compiler.CompilerContext) ([]string, error) { + // Special case: WORD source, BYTE destination + if c.sourceVarKind == compiler.KindWord { + return c.generateWordToByteShift(ctx) + } + var asm []string - - // Check if shift amount >= bit width (result will be zero) - bitWidth := 8 - if c.destVarKind == compiler.KindWord { - bitWidth = 16 - } - - amountZero := false + + // Constant shift amount if !c.amountIsVar { - // Constant amount - if c.amountValue >= uint16(bitWidth) { - amountZero = true + amount := c.amountValue + + // Shift >= 8 bits -> zero + if amount >= 8 { + if amount > 0 { + _, _ = fmt.Fprintf(os.Stderr, "%s:%d: warning: shift amount %d >= 8 bits, value will be zero\n", + c.line.Filename, c.line.LineNo, amount) + } + + // Store zero + if c.sourceIsVar && c.sourceVarName == c.destVarName { + // Same variable: just zero it + asm = append(asm, "\tlda #0") + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + } else { + // Different: load zero and store + asm = append(asm, "\tlda #0") + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + } + return asm, nil } + + // Shift 0 -> copy source + if amount == 0 { + return c.generateByteCopy(), nil + } + + // Constant shift 1-7 + return c.generateByteShiftConst(amount), nil } + + // Variable shift amount + return c.generateByteShiftVar(ctx) +} - if amountZero { - // Result is zero, just store zero - if c.destVarKind == compiler.KindByte { - asm = append(asm, "\tlda #0") +// generateByteCopy copies source to destination (BYTE) +func (c *ShiftRCommand) generateByteCopy() []string { + var asm []string + + if c.sourceIsVar { + if c.sourceVarName == c.destVarName { + // Same variable, no copy needed + return asm + } + // Different variable + if c.sourceVarKind == compiler.KindWord { + // WORD source, BYTE destination - just take low byte + asm = append(asm, fmt.Sprintf("\tlda %s", c.sourceVarName)) asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) } else { - asm = append(asm, "\tlda #0") + // BYTE source + asm = append(asm, fmt.Sprintf("\tlda %s", c.sourceVarName)) asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) - asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) } - return asm, nil - } - - // Step 1: Copy source to destination if needed - if c.sourceIsVar && c.sourceVarName == c.destVarName { - // Same variable, no copy needed } else { - copyAsm := c.generateCopy() - asm = append(asm, copyAsm...) + // Literal + val := uint8(c.sourceValue & 0xFF) + asm = append(asm, fmt.Sprintf("\tlda #$%02x", val)) + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) } + + return asm +} - // Step 2: Apply shift - shiftAsm, err := c.generateShift(ctx) - if err != nil { - return nil, err +// generateByteShiftConst generates BYTE >> constant (1-7) +func (c *ShiftRCommand) generateByteShiftConst(amount uint16) []string { + var asm []string + + // Copy source to destination if needed + copyAsm := c.generateByteCopy() + asm = append(asm, copyAsm...) + + // Apply shift (right shift uses LSR) + for i := uint16(0); i < amount; i++ { + asm = append(asm, fmt.Sprintf("\tlsr %s", c.destVarName)) } - asm = append(asm, shiftAsm...) + + return asm +} +// generateByteShiftVar generates BYTE >> variable +func (c *ShiftRCommand) generateByteShiftVar(ctx *compiler.CompilerContext) ([]string, error) { + var asm []string + + // Copy source to destination if needed + copyAsm := c.generateByteCopy() + asm = append(asm, copyAsm...) + + // Generate labels + loopLabel := ctx.GeneralStack.Push() + ctx.GeneralStack.Pop() + doneLabel := ctx.GeneralStack.Push() + ctx.GeneralStack.Pop() + + // Variable shift loop (right shift uses LSR) + asm = append(asm, fmt.Sprintf("\tldx %s", c.amountVarName)) + asm = append(asm, fmt.Sprintf("\tbeq %s", doneLabel)) + asm = append(asm, loopLabel) + asm = append(asm, fmt.Sprintf("\tlsr %s", c.destVarName)) + asm = append(asm, "\tdex") + asm = append(asm, fmt.Sprintf("\tbne %s", loopLabel)) + asm = append(asm, doneLabel) + return asm, nil } -// generateCopy generates assembly to copy source to destination -func (c *ShiftRCommand) generateCopy() []string { +// generateWordToByteShift handles WORD >> amount -> BYTE +func (c *ShiftRCommand) generateWordToByteShift(ctx *compiler.CompilerContext) ([]string, error) { var asm []string - - // If source is literal, just load it - if !c.sourceIsVar { - lo := uint8(c.sourceValue & 0xFF) - hi := uint8((c.sourceValue >> 8) & 0xFF) - - if c.destVarKind == compiler.KindByte { - asm = append(asm, fmt.Sprintf("\tlda #$%02x", lo)) + + // Constant shift amount + if !c.amountIsVar { + amount := c.amountValue + + // Shift >= 16 bits -> zero + if amount >= 16 { + if amount > 0 { + _, _ = fmt.Fprintf(os.Stderr, "%s:%d: warning: shift amount %d >= 16 bits, value will be zero\n", + c.line.Filename, c.line.LineNo, amount) + } + + // Store zero + asm = append(asm, "\tlda #0") asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) - } else { - asm = append(asm, fmt.Sprintf("\tlda #$%02x", lo)) - asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) - if lo != hi { + return asm, nil + } + + // Shift 0 -> just take low byte + if amount == 0 { + if c.sourceIsVar { + asm = append(asm, fmt.Sprintf("\tlda %s", c.sourceVarName)) + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + } else { + val := uint8(c.sourceValue & 0xFF) + asm = append(asm, fmt.Sprintf("\tlda #$%02x", val)) + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + } + return asm, nil + } + + // For WORD -> BYTE right shift, we need to handle the full WORD shift + // then take the low byte. This is more complex than BYTE shift. + + // Shift 8-15: result depends only on high byte + if amount >= 8 && amount < 16 { + remaining := amount - 8 + + // Get high byte + if c.sourceIsVar { + asm = append(asm, fmt.Sprintf("\tlda %s+1", c.sourceVarName)) + } else { + hi := uint8((c.sourceValue >> 8) & 0xFF) asm = append(asm, fmt.Sprintf("\tlda #$%02x", hi)) } - asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) + + // Store to destination + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + + // Shift remaining bits + for i := uint16(0); i < remaining; i++ { + asm = append(asm, fmt.Sprintf("\tlsr %s", c.destVarName)) + } + + return asm, nil } - return asm + + // Shift 1-7: need full WORD shift, then take low byte + // We'll use the destination as a temporary for the low byte + // and handle high byte in A register + + if c.sourceIsVar { + // Load low byte to destination + asm = append(asm, fmt.Sprintf("\tlda %s", c.sourceVarName)) + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + + // Load high byte to A + asm = append(asm, fmt.Sprintf("\tlda %s+1", c.sourceVarName)) + } else { + // Literal source + lo := uint8(c.sourceValue & 0xFF) + hi := uint8((c.sourceValue >> 8) & 0xFF) + + asm = append(asm, fmt.Sprintf("\tlda #$%02x", lo)) + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + asm = append(asm, fmt.Sprintf("\tlda #$%02x", hi)) + } + + // Apply shifts: lsr A (high byte), ror destination (low byte) + for i := uint16(0); i < amount; i++ { + asm = append(asm, "\tlsr") + asm = append(asm, fmt.Sprintf("\tror %s", c.destVarName)) + } + + return asm, nil } - - // Source is variable - if c.destVarKind == compiler.KindByte { - // Destination is byte + + // Variable shift amount - use general approach with loop + // We need X for shift count, A for high byte, destination for low byte + + // Generate labels + loopLabel := ctx.GeneralStack.Push() + ctx.GeneralStack.Pop() + doneLabel := ctx.GeneralStack.Push() + ctx.GeneralStack.Pop() + + if c.sourceIsVar { + // Load low byte to destination asm = append(asm, fmt.Sprintf("\tlda %s", c.sourceVarName)) asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + + // Load high byte to A + asm = append(asm, fmt.Sprintf("\tlda %s+1", c.sourceVarName)) } else { - // Destination is word + // Literal source + lo := uint8(c.sourceValue & 0xFF) + hi := uint8((c.sourceValue >> 8) & 0xFF) + + asm = append(asm, fmt.Sprintf("\tlda #$%02x", lo)) + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + asm = append(asm, fmt.Sprintf("\tlda #$%02x", hi)) + } + + // Load shift amount to X + asm = append(asm, fmt.Sprintf("\tldx %s", c.amountVarName)) + asm = append(asm, fmt.Sprintf("\tbeq %s", doneLabel)) + + // Shift loop + asm = append(asm, loopLabel) + asm = append(asm, "\tlsr") + asm = append(asm, fmt.Sprintf("\tror %s", c.destVarName)) + asm = append(asm, "\tdex") + asm = append(asm, fmt.Sprintf("\tbne %s", loopLabel)) + asm = append(asm, doneLabel) + + return asm, nil +} + +// generateWordShift handles WORD >> amount +func (c *ShiftRCommand) generateWordShift(ctx *compiler.CompilerContext) ([]string, error) { + // Constant shift amount + if !c.amountIsVar { + amount := c.amountValue + + // Shift >= 16 bits -> zero + if amount >= 16 { + if amount > 0 { + _, _ = fmt.Fprintf(os.Stderr, "%s:%d: warning: shift amount %d >= 16 bits, value will be zero\n", + c.line.Filename, c.line.LineNo, amount) + } + + // Store zero + if c.sourceIsVar && c.sourceVarName == c.destVarName { + // Same variable: zero both bytes + asm := []string{ + "\tlda #0", + fmt.Sprintf("\tsta %s", c.destVarName), + fmt.Sprintf("\tsta %s+1", c.destVarName), + } + return asm, nil + } else { + // Different: load zero and store to both bytes + asm := []string{ + "\tlda #0", + fmt.Sprintf("\tsta %s", c.destVarName), + fmt.Sprintf("\tsta %s+1", c.destVarName), + } + return asm, nil + } + } + + // Shift 0 -> copy source + if amount == 0 { + return c.generateWordCopy(), nil + } + + // Constant shift 1-15 + return c.generateWordShiftConst(amount), nil + } + + // Variable shift amount + return c.generateWordShiftVar(ctx) +} + +// generateWordCopy copies source to destination (WORD) +func (c *ShiftRCommand) generateWordCopy() []string { + var asm []string + + if c.sourceIsVar && c.sourceVarName == c.destVarName { + // Same variable, no copy needed + return asm + } + + if c.sourceIsVar { + // Variable source if c.sourceVarKind == compiler.KindByte { // Byte -> Word (zero-extend) asm = append(asm, fmt.Sprintf("\tlda %s", c.sourceVarName)) @@ -265,84 +502,96 @@ func (c *ShiftRCommand) generateCopy() []string { asm = append(asm, fmt.Sprintf("\tlda %s+1", c.sourceVarName)) asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) } + } else { + // Literal source + lo := uint8(c.sourceValue & 0xFF) + hi := uint8((c.sourceValue >> 8) & 0xFF) + asm = append(asm, fmt.Sprintf("\tlda #$%02x", lo)) + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + if lo != hi { + asm = append(asm, fmt.Sprintf("\tlda #$%02x", hi)) + } + asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) } - + return asm } -// generateShift generates assembly to shift destination right by amount -func (c *ShiftRCommand) generateShift(ctx *compiler.CompilerContext) ([]string, error) { +// generateWordShiftConst generates WORD >> constant (1-15) +func (c *ShiftRCommand) generateWordShiftConst(amount uint16) []string { var asm []string - - // Constant amount - if !c.amountIsVar { - amount := c.amountValue - if amount == 0 { - return asm, nil // No shift needed - } - - // Determine bit width - bitWidth := 8 - if c.destVarKind == compiler.KindWord { - bitWidth = 16 - } - - // Warn if shift amount >= bit width (but not for 0) - if amount >= uint16(bitWidth) { - _, _ = fmt.Fprintf(os.Stderr, "%s:%d: warning: shift amount %d >= %d bits, value will be zero\n", - c.line.Filename, c.line.LineNo, amount, bitWidth) - } - - if amount >= uint16(bitWidth) { - // Shift all bits out -> zero - if c.destVarKind == compiler.KindByte { - asm = append(asm, "\tlda #0") - asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + + // Special case: shift >= 8 + if amount >= 8 && amount < 16 { + remaining := amount - 8 + + // Get source high byte + if c.sourceIsVar { + if c.sourceVarName == c.destVarName { + // Same variable: read from destination (already has value) + asm = append(asm, fmt.Sprintf("\tlda %s+1", c.destVarName)) } else { - asm = append(asm, "\tlda #0") - asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) - asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) + // Different: read from source + asm = append(asm, fmt.Sprintf("\tlda %s+1", c.sourceVarName)) } - return asm, nil + } else { + // Literal + hi := uint8((c.sourceValue >> 8) & 0xFF) + asm = append(asm, fmt.Sprintf("\tlda #$%02x", hi)) } - - // Unroll shift loop - for i := uint16(0); i < amount; i++ { - if c.destVarKind == compiler.KindByte { - asm = append(asm, fmt.Sprintf("\tlsr %s", c.destVarName)) - } else { - asm = append(asm, fmt.Sprintf("\tlsr %s+1", c.destVarName)) - asm = append(asm, fmt.Sprintf("\tror %s", c.destVarName)) - } + + // Store to destination low byte + asm = append(asm, fmt.Sprintf("\tsta %s", c.destVarName)) + + // Zero high byte + asm = append(asm, "\tlda #0") + asm = append(asm, fmt.Sprintf("\tsta %s+1", c.destVarName)) + + // Shift low byte remaining bits (right shift) + for i := uint16(0); i < remaining; i++ { + asm = append(asm, fmt.Sprintf("\tlsr %s", c.destVarName)) } - return asm, nil + + return asm } + + // Shift 1-7: normal copy and shift + copyAsm := c.generateWordCopy() + asm = append(asm, copyAsm...) + + // Apply shift (right shift uses LSR high byte, ROR low byte) + for i := uint16(0); i < amount; i++ { + asm = append(asm, fmt.Sprintf("\tlsr %s+1", c.destVarName)) + asm = append(asm, fmt.Sprintf("\tror %s", c.destVarName)) + } + + return asm +} - // Variable amount +// generateWordShiftVar generates WORD >> variable +func (c *ShiftRCommand) generateWordShiftVar(ctx *compiler.CompilerContext) ([]string, error) { + var asm []string + + // Copy source to destination if needed + copyAsm := c.generateWordCopy() + asm = append(asm, copyAsm...) + // Generate labels loopLabel := ctx.GeneralStack.Push() ctx.GeneralStack.Pop() doneLabel := ctx.GeneralStack.Push() ctx.GeneralStack.Pop() - - - // Load amount into X + + // Variable shift loop (right shift uses LSR high byte, ROR low byte) asm = append(asm, fmt.Sprintf("\tldx %s", c.amountVarName)) - - // Check for zero amount asm = append(asm, fmt.Sprintf("\tbeq %s", doneLabel)) - - // Shift loop asm = append(asm, loopLabel) - if c.destVarKind == compiler.KindByte { - asm = append(asm, fmt.Sprintf("\tlsr %s", c.destVarName)) - } else { - asm = append(asm, fmt.Sprintf("\tlsr %s+1", c.destVarName)) - asm = append(asm, fmt.Sprintf("\tror %s", c.destVarName)) - } + asm = append(asm, fmt.Sprintf("\tlsr %s+1", c.destVarName)) + asm = append(asm, fmt.Sprintf("\tror %s", c.destVarName)) asm = append(asm, "\tdex") asm = append(asm, fmt.Sprintf("\tbne %s", loopLabel)) - asm = append(asm, doneLabel) + return asm, nil -} \ No newline at end of file +} + diff --git a/internal/commands/shiftr_test.go b/internal/commands/shiftr_test.go index 8d9ddb7..15a59d0 100644 --- a/internal/commands/shiftr_test.go +++ b/internal/commands/shiftr_test.go @@ -518,6 +518,136 @@ func TestShiftRCommand_Generate(t *testing.T) { "_L2", }, }, + { + name: "WORD optimization - shift by 8", + setup: func(ctx *compiler.CompilerContext) *ShiftRCommand { + ctx.SymbolTable.AddVar("result", "", compiler.KindWord, 0, preproc.Line{Filename: "test.c65", LineNo: 1}) + return &ShiftRCommand{ + sourceIsVar: false, + sourceValue: 0x1234, + amountIsVar: false, + amountValue: 8, + destVarName: "result", + destVarKind: compiler.KindWord, + } + }, + wantLines: []string{ + "\tlda #$12", + "\tsta result", + "\tlda #0", + "\tsta result+1", + }, + }, + { + name: "WORD optimization - shift by 12 (8 + 4)", + setup: func(ctx *compiler.CompilerContext) *ShiftRCommand { + ctx.SymbolTable.AddVar("result", "", compiler.KindWord, 0, preproc.Line{Filename: "test.c65", LineNo: 1}) + return &ShiftRCommand{ + sourceIsVar: false, + sourceValue: 0xAB00, + amountIsVar: false, + amountValue: 12, + destVarName: "result", + destVarKind: compiler.KindWord, + } + }, + wantLines: []string{ + "\tlda #$ab", + "\tsta result", + "\tlda #0", + "\tsta result+1", + "\tlsr result", + "\tlsr result", + "\tlsr result", + "\tlsr result", + }, + }, + { + name: "BYTE to WORD conversion with shift", + setup: func(ctx *compiler.CompilerContext) *ShiftRCommand { + ctx.SymbolTable.AddVar("byteval", "", compiler.KindByte, 0x55, preproc.Line{Filename: "test.c65", LineNo: 1}) + ctx.SymbolTable.AddVar("result", "", compiler.KindWord, 0, preproc.Line{Filename: "test.c65", LineNo: 1}) + return &ShiftRCommand{ + sourceIsVar: true, + sourceVarName: "byteval", + sourceVarKind: compiler.KindByte, + amountIsVar: false, + amountValue: 4, + destVarName: "result", + destVarKind: compiler.KindWord, + } + }, + wantLines: []string{ + "\tlda byteval", + "\tsta result", + "\tlda #0", + "\tsta result+1", + "\tlsr result+1", + "\tror result", + "\tlsr result+1", + "\tror result", + "\tlsr result+1", + "\tror result", + "\tlsr result+1", + "\tror result", + }, + }, + { + name: "WORD to BYTE conversion with shift (truncation)", + setup: func(ctx *compiler.CompilerContext) *ShiftRCommand { + ctx.SymbolTable.AddVar("wordval", "", compiler.KindWord, 0x1234, preproc.Line{Filename: "test.c65", LineNo: 1}) + ctx.SymbolTable.AddVar("result", "", compiler.KindByte, 0, preproc.Line{Filename: "test.c65", LineNo: 1}) + return &ShiftRCommand{ + sourceIsVar: true, + sourceVarName: "wordval", + sourceVarKind: compiler.KindWord, + amountIsVar: false, + amountValue: 4, + destVarName: "result", + destVarKind: compiler.KindByte, + } + }, + wantLines: []string{ + "\tlda wordval", + "\tsta result", + "\tlda wordval+1", + "\tlsr", + "\tror result", + "\tlsr", + "\tror result", + "\tlsr", + "\tror result", + "\tlsr", + "\tror result", + }, + }, + { + name: "WORD same source and destination", + setup: func(ctx *compiler.CompilerContext) *ShiftRCommand { + ctx.SymbolTable.AddVar("value", "", compiler.KindWord, 0x1234, preproc.Line{Filename: "test.c65", LineNo: 1}) + ctx.SymbolTable.AddVar("shift", "", compiler.KindByte, 2, preproc.Line{Filename: "test.c65", LineNo: 1}) + return &ShiftRCommand{ + sourceIsVar: true, + sourceVarName: "value", + sourceVarKind: compiler.KindWord, + amountIsVar: true, + amountVarName: "shift", + amountVarKind: compiler.KindByte, + destVarName: "value", + destVarKind: compiler.KindWord, + } + }, + wantLines: []string{ + "\tldx shift", + "\tbeq _L2", + "_L1", + "\tlsr value+1", + "\tror value", + "\tdex", + "\tbne _L1", + "_L2", + }, + }, } for _, tt := range tests {