c65gm/internal/commands/if_test.go

571 lines
12 KiB
Go

package commands
import (
"strings"
"testing"
"c65gm/internal/compiler"
"c65gm/internal/preproc"
)
// TestIfCommand_WillHandle verifies that IfCommand accepts IF statements
// (with or without a trailing THEN) and declines other statements and empty
// lines.
func TestIfCommand_WillHandle(t *testing.T) {
	handler := &IfCommand{}
	cases := []struct {
		name string
		line string
		want bool
	}{
		{name: "basic IF", line: "IF a = b", want: true},
		{name: "IF with THEN", line: "IF a = b THEN", want: true},
		{name: "not IF", line: "LET a = b", want: false},
		{name: "empty", line: "", want: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			src := preproc.Line{Text: tc.line, Kind: preproc.Source}
			if got := handler.WillHandle(src); got != tc.want {
				t.Errorf("WillHandle() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestIfCommand_Equal_ShortJump checks the code IfCommand generates for the
// "=" and "==" operators with default pragmas (no _P_USE_LONG_JUMP).
//
// Expected shapes, as pinned by the cases below:
//   - byte operands: one lda/cmp followed by "bne _I1", i.e. a branch to _I1
//     when the condition is false;
//   - word operands: low byte compared first, then the high byte (x+1), each
//     followed by "bne _I1";
//   - mixed byte/word: the word's high byte must be zero, then the low bytes
//     are compared;
//   - literal-only conditions are folded: a true condition emits nothing and a
//     false one emits an unconditional "jmp _I1".
func TestIfCommand_Equal_ShortJump(t *testing.T) {
	tests := []struct {
		name      string
		line      string
		setupVars func(*compiler.SymbolTable)
		wantAsm   []string
	}{
		{
			name: "byte var == byte var",
			line: "IF a = b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 10)
				st.AddVar("b", "", compiler.KindByte, 20)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp b",
				"\tbne _I1",
			},
		},
		{
			name: "byte var == byte literal",
			line: "IF a = 100",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 0)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp #$64",
				"\tbne _I1",
			},
		},
		{
			name: "byte literal == byte var",
			line: "IF 100 = a",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 0)
			},
			wantAsm: []string{
				"\tlda #$64",
				"\tcmp a",
				"\tbne _I1",
			},
		},
		{
			name: "word var == word var",
			line: "IF x = y",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0x1234)
				st.AddVar("y", "", compiler.KindWord, 0x5678)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp y",
				"\tbne _I1",
				"\tlda x+1",
				"\tcmp y+1",
				"\tbne _I1",
			},
		},
		{
			name: "word var == word literal",
			line: "IF x == $1234",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp #$34",
				"\tbne _I1",
				"\tlda x+1",
				"\tcmp #$12",
				"\tbne _I1",
			},
		},
		{
			name: "word literal == word var",
			line: "IF $1234 = x",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0)
			},
			wantAsm: []string{
				"\tlda #$34",
				"\tcmp x",
				"\tbne _I1",
				"\tlda #$12",
				"\tcmp x+1",
				"\tbne _I1",
			},
		},
		{
			name: "byte var == word var (mixed)",
			line: "IF b = x",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("b", "", compiler.KindByte, 50)
				st.AddVar("x", "", compiler.KindWord, 100)
			},
			wantAsm: []string{
				"\tlda x+1",
				"\tcmp #0",
				"\tbne _I1",
				"\tlda b",
				"\tcmp x",
				"\tbne _I1",
			},
		},
		{
			name: "word var == byte var (mixed)",
			line: "IF x = b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 100)
				st.AddVar("b", "", compiler.KindByte, 50)
			},
			wantAsm: []string{
				"\tlda x+1",
				"\tcmp #0",
				"\tbne _I1",
				"\tlda b",
				"\tcmp x",
				"\tbne _I1",
			},
		},
		{
			name:      "constant folding - equal",
			line:      "IF 100 = 100",
			setupVars: func(st *compiler.SymbolTable) {},
			wantAsm:   []string{},
		},
		{
			name:      "constant folding - not equal",
			line:      "IF 100 = 200",
			setupVars: func(st *compiler.SymbolTable) {},
			wantAsm: []string{
				"\tjmp _I1",
			},
		},
		{
			name: "with THEN keyword",
			line: "IF a = b THEN",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 0)
				st.AddVar("b", "", compiler.KindByte, 0)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp b",
				"\tbne _I1",
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			pragma := preproc.NewPragma()
			ctx := compiler.NewCompilerContext(pragma)
			tt.setupVars(ctx.SymbolTable)
			cmd := &IfCommand{}
			// Ask the pragma for its current set index (as the long-jump tests
			// do) instead of hard-coding 0, so the test does not depend on how
			// Pragma numbers its initial set.
			line := preproc.Line{
				Text:           tt.line,
				Kind:           preproc.Source,
				PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
			}
			if err := cmd.Interpret(line, ctx); err != nil {
				t.Fatalf("Interpret() error = %v", err)
			}
			asm, err := cmd.Generate(ctx)
			if err != nil {
				t.Fatalf("Generate() error = %v", err)
			}
			if !equalAsmIf(asm, tt.wantAsm) {
				t.Errorf("Generate() mismatch\ngot:\n%s\nwant:\n%s",
					strings.Join(asm, "\n"),
					strings.Join(tt.wantAsm, "\n"))
			}
		})
	}
}
// TestIfCommand_Equal_LongJump checks "=" code generation when the
// _P_USE_LONG_JUMP pragma is set to "1". In this mode the conditional branch
// is inverted so that it skips over an unconditional "jmp _I1" (presumably so
// _I1 may lie outside relative-branch range), using local _L labels as the
// branch targets.
func TestIfCommand_Equal_LongJump(t *testing.T) {
	type ifCase struct {
		name      string
		line      string
		setupVars func(*compiler.SymbolTable)
		wantAsm   []string
	}

	// check runs a single case: interpret the IF line under a long-jump
	// pragma, generate code, and compare it to the expected listing.
	check := func(t *testing.T, c ifCase) {
		p := preproc.NewPragma()
		p.AddPragma("_P_USE_LONG_JUMP", "1")
		cc := compiler.NewCompilerContext(p)
		c.setupVars(cc.SymbolTable)
		ifCmd := &IfCommand{}
		src := preproc.Line{
			Text:           c.line,
			Kind:           preproc.Source,
			PragmaSetIndex: p.GetCurrentPragmaSetIndex(),
		}
		if err := ifCmd.Interpret(src, cc); err != nil {
			t.Fatalf("Interpret() error = %v", err)
		}
		got, err := ifCmd.Generate(cc)
		if err != nil {
			t.Fatalf("Generate() error = %v", err)
		}
		if !equalAsmIf(got, c.wantAsm) {
			t.Errorf("Generate() mismatch\ngot:\n%s\nwant:\n%s",
				strings.Join(got, "\n"),
				strings.Join(c.wantAsm, "\n"))
		}
	}

	cases := []ifCase{
		{
			name: "byte var == byte var (long jump)",
			line: "IF a = b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 10)
				st.AddVar("b", "", compiler.KindByte, 20)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp b",
				"\tbeq _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
		{
			name: "byte var == byte literal (long jump)",
			line: "IF a = 100",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 0)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp #$64",
				"\tbeq _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
		{
			name: "word var == word var (long jump)",
			line: "IF x = y",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0x1234)
				st.AddVar("y", "", compiler.KindWord, 0x5678)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp y",
				"\tbne _L2",
				"\tlda x+1",
				"\tcmp y+1",
				"\tbeq _L1",
				"_L2",
				"\tjmp _I1",
				"_L1",
			},
		},
	}

	for _, c := range cases {
		c := c
		t.Run(c.name, func(t *testing.T) { check(t, c) })
	}
}
// TestIfCommand_NotEqual_ShortJump checks the code IfCommand generates for the
// "<>" and "!=" operators with default pragmas (no _P_USE_LONG_JUMP).
//
// Expected shapes, as pinned by the cases below:
//   - byte operands: one lda/cmp followed by "beq _I1", i.e. a branch to _I1
//     when the operands are equal;
//   - word operands: either byte differing makes the condition true, so each
//     byte comparison branches ("bne") to a local _L1 label that sits after an
//     unconditional "jmp _I1";
//   - mixed byte/word: a non-zero high byte in the word jumps to _L1, otherwise
//     equal low bytes branch to _I1;
//   - literal-only conditions are folded: a true condition emits nothing and a
//     false one emits an unconditional "jmp _I1".
func TestIfCommand_NotEqual_ShortJump(t *testing.T) {
	tests := []struct {
		name      string
		line      string
		setupVars func(*compiler.SymbolTable)
		wantAsm   []string
	}{
		{
			name: "byte var != byte var",
			line: "IF a <> b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 10)
				st.AddVar("b", "", compiler.KindByte, 20)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp b",
				"\tbeq _I1",
			},
		},
		{
			name: "byte var != byte literal",
			line: "IF a != 100",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 0)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp #$64",
				"\tbeq _I1",
			},
		},
		{
			name: "word var != word var",
			line: "IF x <> y",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0x1234)
				st.AddVar("y", "", compiler.KindWord, 0x5678)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp y",
				"\tbne _L1",
				"\tlda x+1",
				"\tcmp y+1",
				"\tbne _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
		{
			name: "word var != word literal",
			line: "IF x != $1234",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp #$34",
				"\tbne _L1",
				"\tlda x+1",
				"\tcmp #$12",
				"\tbne _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
		{
			name: "byte var != word var (mixed)",
			line: "IF b <> x",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("b", "", compiler.KindByte, 50)
				st.AddVar("x", "", compiler.KindWord, 100)
			},
			wantAsm: []string{
				"\tlda x+1",
				"\tcmp #0",
				"\tbne _L1",
				"\tlda b",
				"\tcmp x",
				"\tbeq _I1",
				"_L1",
			},
		},
		{
			name:      "constant folding - not equal",
			line:      "IF 100 <> 200",
			setupVars: func(st *compiler.SymbolTable) {},
			wantAsm:   []string{},
		},
		{
			name:      "constant folding - equal",
			line:      "IF 100 != 100",
			setupVars: func(st *compiler.SymbolTable) {},
			wantAsm: []string{
				"\tjmp _I1",
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			pragma := preproc.NewPragma()
			ctx := compiler.NewCompilerContext(pragma)
			tt.setupVars(ctx.SymbolTable)
			cmd := &IfCommand{}
			// Ask the pragma for its current set index (as the long-jump tests
			// do) instead of hard-coding 0, so the test does not depend on how
			// Pragma numbers its initial set.
			line := preproc.Line{
				Text:           tt.line,
				Kind:           preproc.Source,
				PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
			}
			if err := cmd.Interpret(line, ctx); err != nil {
				t.Fatalf("Interpret() error = %v", err)
			}
			asm, err := cmd.Generate(ctx)
			if err != nil {
				t.Fatalf("Generate() error = %v", err)
			}
			if !equalAsmIf(asm, tt.wantAsm) {
				t.Errorf("Generate() mismatch\ngot:\n%s\nwant:\n%s",
					strings.Join(asm, "\n"),
					strings.Join(tt.wantAsm, "\n"))
			}
		})
	}
}
// TestIfCommand_NotEqual_LongJump checks "<>" code generation when the
// _P_USE_LONG_JUMP pragma is set to "1". A differing byte branches ("bne") to
// a local _L1 label placed after an unconditional "jmp _I1", so the jump to
// _I1 is taken only when the operands are equal.
func TestIfCommand_NotEqual_LongJump(t *testing.T) {
	type ifCase struct {
		name      string
		line      string
		setupVars func(*compiler.SymbolTable)
		wantAsm   []string
	}

	// check runs a single case: interpret the IF line under a long-jump
	// pragma, generate code, and compare it to the expected listing.
	check := func(t *testing.T, c ifCase) {
		p := preproc.NewPragma()
		p.AddPragma("_P_USE_LONG_JUMP", "1")
		cc := compiler.NewCompilerContext(p)
		c.setupVars(cc.SymbolTable)
		ifCmd := &IfCommand{}
		src := preproc.Line{
			Text:           c.line,
			Kind:           preproc.Source,
			PragmaSetIndex: p.GetCurrentPragmaSetIndex(),
		}
		if err := ifCmd.Interpret(src, cc); err != nil {
			t.Fatalf("Interpret() error = %v", err)
		}
		got, err := ifCmd.Generate(cc)
		if err != nil {
			t.Fatalf("Generate() error = %v", err)
		}
		if !equalAsmIf(got, c.wantAsm) {
			t.Errorf("Generate() mismatch\ngot:\n%s\nwant:\n%s",
				strings.Join(got, "\n"),
				strings.Join(c.wantAsm, "\n"))
		}
	}

	cases := []ifCase{
		{
			name: "byte var != byte var (long jump)",
			line: "IF a <> b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 10)
				st.AddVar("b", "", compiler.KindByte, 20)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp b",
				"\tbne _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
		{
			name: "word var != word var (long jump)",
			line: "IF x <> y",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0x1234)
				st.AddVar("y", "", compiler.KindWord, 0x5678)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp y",
				"\tbne _L1",
				"\tlda x+1",
				"\tcmp y+1",
				"\tbne _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
	}

	for _, c := range cases {
		c := c
		t.Run(c.name, func(t *testing.T) { check(t, c) })
	}
}
// TestIfCommand_Errors checks that Interpret rejects malformed IF statements
// and that the returned error message contains the expected substring.
func TestIfCommand_Errors(t *testing.T) {
	tests := []struct {
		name      string
		line      string
		setupVars func(*compiler.SymbolTable)
		wantErr   string
	}{
		{
			name:      "wrong param count",
			line:      "IF a = b = c",
			setupVars: func(st *compiler.SymbolTable) {},
			wantErr:   "wrong number of parameters",
		},
		{
			name:      "unsupported operator",
			line:      "IF a > b",
			setupVars: func(st *compiler.SymbolTable) {},
			wantErr:   "unsupported operator",
		},
		{
			name: "unknown variable",
			line: "IF unknown = b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("b", "", compiler.KindByte, 0)
			},
			wantErr: "not a valid variable or expression",
		},
		{
			name:      "invalid THEN",
			line:      "IF a = b NOT",
			setupVars: func(st *compiler.SymbolTable) {},
			wantErr:   "must be 'THEN'",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			pragma := preproc.NewPragma()
			ctx := compiler.NewCompilerContext(pragma)
			tt.setupVars(ctx.SymbolTable)
			cmd := &IfCommand{}
			// Ask the pragma for its current set index (as the long-jump tests
			// do) instead of hard-coding 0.
			line := preproc.Line{
				Text:           tt.line,
				Kind:           preproc.Source,
				PragmaSetIndex: pragma.GetCurrentPragmaSetIndex(),
			}
			err := cmd.Interpret(line, ctx)
			if err == nil {
				// Name the input and the expected error so a failure is
				// self-explanatory.
				t.Fatalf("Interpret(%q) error = nil, want error containing %q", tt.line, tt.wantErr)
			}
			if !strings.Contains(err.Error(), tt.wantErr) {
				t.Errorf("error = %q, want substring %q", err.Error(), tt.wantErr)
			}
		})
	}
}
// equalAsmIf reports whether two assembly listings contain the same lines in
// the same order. Only lengths and elements are compared, so a nil slice and
// an empty slice are considered equal.
func equalAsmIf(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for i, line := range a {
		if line != b[i] {
			return false
		}
	}
	return true
}