package commands

import (
	"strings"
	"testing"

	"c65gm/internal/compiler"
	"c65gm/internal/preproc"
)

// ifAsmCase is one table entry for the IF code-generation tests: a source
// line, a hook that declares the variables the line refers to, and the exact
// assembly lines Generate is expected to return.
type ifAsmCase struct {
	name      string
	line      string
	setupVars func(*compiler.SymbolTable)
	wantAsm   []string
}

// runIfAsmCases runs every case as a subtest: it builds a fresh compiler
// context, lets the case declare its variables, feeds the line through
// IfCommand.Interpret and IfCommand.Generate, and compares the result with
// wantAsm line by line.
//
// When longJump is true the "_P_USE_LONG_JUMP" pragma is set to "1" before the
// context is created and the line is tagged with the pragma set index that is
// current after that change; otherwise the line uses pragma set index 0.
func runIfAsmCases(t *testing.T, longJump bool, cases []ifAsmCase) {
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			pragma := preproc.NewPragma()
			if longJump {
				pragma.AddPragma("_P_USE_LONG_JUMP", "1")
			}
			ctx := compiler.NewCompilerContext(pragma)
			tc.setupVars(ctx.SymbolTable)

			cmd := &IfCommand{}
			line := preproc.Line{Text: tc.line, Kind: preproc.Source, PragmaSetIndex: 0}
			if longJump {
				line.PragmaSetIndex = pragma.GetCurrentPragmaSetIndex()
			}

			if err := cmd.Interpret(line, ctx); err != nil {
				t.Fatalf("Interpret() error = %v", err)
			}

			asm, err := cmd.Generate(ctx)
			if err != nil {
				t.Fatalf("Generate() error = %v", err)
			}

			if !equalAsmIf(asm, tc.wantAsm) {
				t.Errorf("Generate() mismatch\ngot:\n%s\nwant:\n%s",
					strings.Join(asm, "\n"),
					strings.Join(tc.wantAsm, "\n"))
			}
		})
	}
}

// TestIfCommand_WillHandle checks which source lines IfCommand claims: lines
// starting with IF (with or without a trailing THEN) are accepted, while a LET
// line and an empty line are rejected.
func TestIfCommand_WillHandle(t *testing.T) {
	cmd := &IfCommand{}

	tests := []struct {
		name string
		line string
		want bool
	}{
		{"basic IF", "IF a = b", true},
		{"IF with THEN", "IF a = b THEN", true},
		{"not IF", "LET a = b", false},
		{"empty", "", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			line := preproc.Line{Text: tt.line, Kind: preproc.Source}
			got := cmd.WillHandle(line)
			if got != tt.want {
				t.Errorf("WillHandle() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestIfCommand_Equal_ShortJump pins the assembly expected for equality
// comparisons ("=" and "==") without the long-jump pragma: every failed
// comparison branches straight to _I1 with bne. It covers byte and word
// operands, literals on either side, mixed widths, compile-time constant
// folding and the optional THEN keyword.
func TestIfCommand_Equal_ShortJump(t *testing.T) {
	runIfAsmCases(t, false, []ifAsmCase{
		{
			name: "byte var == byte var",
			line: "IF a = b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 10)
				st.AddVar("b", "", compiler.KindByte, 20)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp b",
				"\tbne _I1",
			},
		},
		{
			name: "byte var == byte literal",
			line: "IF a = 100",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 0)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp #$64",
				"\tbne _I1",
			},
		},
		{
			name: "byte literal == byte var",
			line: "IF 100 = a",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 0)
			},
			wantAsm: []string{
				"\tlda #$64",
				"\tcmp a",
				"\tbne _I1",
			},
		},
		{
			name: "word var == word var",
			line: "IF x = y",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0x1234)
				st.AddVar("y", "", compiler.KindWord, 0x5678)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp y",
				"\tbne _I1",
				"\tlda x+1",
				"\tcmp y+1",
				"\tbne _I1",
			},
		},
		{
			name: "word var == word literal",
			line: "IF x == $1234",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp #$34",
				"\tbne _I1",
				"\tlda x+1",
				"\tcmp #$12",
				"\tbne _I1",
			},
		},
		{
			name: "word literal == word var",
			line: "IF $1234 = x",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0)
			},
			wantAsm: []string{
				"\tlda #$34",
				"\tcmp x",
				"\tbne _I1",
				"\tlda #$12",
				"\tcmp x+1",
				"\tbne _I1",
			},
		},
		{
			name: "byte var == word var (mixed)",
			line: "IF b = x",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("b", "", compiler.KindByte, 50)
				st.AddVar("x", "", compiler.KindWord, 100)
			},
			wantAsm: []string{
				"\tlda x+1",
				"\tcmp #0",
				"\tbne _I1",
				"\tlda b",
				"\tcmp x",
				"\tbne _I1",
			},
		},
		{
			name: "word var == byte var (mixed)",
			line: "IF x = b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 100)
				st.AddVar("b", "", compiler.KindByte, 50)
			},
			wantAsm: []string{
				"\tlda x+1",
				"\tcmp #0",
				"\tbne _I1",
				"\tlda b",
				"\tcmp x",
				"\tbne _I1",
			},
		},
		{
			// Both operands are constants and equal: no code is expected.
			name:      "constant folding - equal",
			line:      "IF 100 = 100",
			setupVars: func(st *compiler.SymbolTable) {},
			wantAsm:   []string{},
		},
		{
			// Both operands are constants and differ: an unconditional jump
			// to _I1 is expected.
			name:      "constant folding - not equal",
			line:      "IF 100 = 200",
			setupVars: func(st *compiler.SymbolTable) {},
			wantAsm: []string{
				"\tjmp _I1",
			},
		},
		{
			name: "with THEN keyword",
			line: "IF a = b THEN",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 0)
				st.AddVar("b", "", compiler.KindByte, 0)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp b",
				"\tbne _I1",
			},
		},
	})
}

// TestIfCommand_Equal_LongJump pins the assembly expected for equality
// comparisons when "_P_USE_LONG_JUMP" is "1": the conditional branch targets
// a local _L label and _I1 is reached through an explicit jmp.
func TestIfCommand_Equal_LongJump(t *testing.T) {
	runIfAsmCases(t, true, []ifAsmCase{
		{
			name: "byte var == byte var (long jump)",
			line: "IF a = b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 10)
				st.AddVar("b", "", compiler.KindByte, 20)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp b",
				"\tbeq _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
		{
			name: "byte var == byte literal (long jump)",
			line: "IF a = 100",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 0)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp #$64",
				"\tbeq _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
		{
			name: "word var == word var (long jump)",
			line: "IF x = y",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0x1234)
				st.AddVar("y", "", compiler.KindWord, 0x5678)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp y",
				"\tbne _L2",
				"\tlda x+1",
				"\tcmp y+1",
				"\tbeq _L1",
				"_L2",
				"\tjmp _I1",
				"_L1",
			},
		},
	})
}

// TestIfCommand_NotEqual_ShortJump pins the assembly expected for inequality
// comparisons ("<>" and "!=") without the long-jump pragma, including word
// operands (which need a local _L label), mixed widths and constant folding.
func TestIfCommand_NotEqual_ShortJump(t *testing.T) {
	runIfAsmCases(t, false, []ifAsmCase{
		{
			name: "byte var != byte var",
			line: "IF a <> b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 10)
				st.AddVar("b", "", compiler.KindByte, 20)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp b",
				"\tbeq _I1",
			},
		},
		{
			name: "byte var != byte literal",
			line: "IF a != 100",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 0)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp #$64",
				"\tbeq _I1",
			},
		},
		{
			name: "word var != word var",
			line: "IF x <> y",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0x1234)
				st.AddVar("y", "", compiler.KindWord, 0x5678)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp y",
				"\tbne _L1",
				"\tlda x+1",
				"\tcmp y+1",
				"\tbne _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
		{
			name: "word var != word literal",
			line: "IF x != $1234",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp #$34",
				"\tbne _L1",
				"\tlda x+1",
				"\tcmp #$12",
				"\tbne _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
		{
			name: "byte var != word var (mixed)",
			line: "IF b <> x",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("b", "", compiler.KindByte, 50)
				st.AddVar("x", "", compiler.KindWord, 100)
			},
			wantAsm: []string{
				"\tlda x+1",
				"\tcmp #0",
				"\tbne _L1",
				"\tlda b",
				"\tcmp x",
				"\tbeq _I1",
				"_L1",
			},
		},
		{
			// Constants differ, so the condition always holds: no code is
			// expected.
			name:      "constant folding - not equal",
			line:      "IF 100 <> 200",
			setupVars: func(st *compiler.SymbolTable) {},
			wantAsm:   []string{},
		},
		{
			// Constants are equal, so the condition never holds: an
			// unconditional jump to _I1 is expected.
			name:      "constant folding - equal",
			line:      "IF 100 != 100",
			setupVars: func(st *compiler.SymbolTable) {},
			wantAsm: []string{
				"\tjmp _I1",
			},
		},
	})
}

// TestIfCommand_NotEqual_LongJump pins the assembly expected for inequality
// comparisons when "_P_USE_LONG_JUMP" is "1".
func TestIfCommand_NotEqual_LongJump(t *testing.T) {
	runIfAsmCases(t, true, []ifAsmCase{
		{
			name: "byte var != byte var (long jump)",
			line: "IF a <> b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("a", "", compiler.KindByte, 10)
				st.AddVar("b", "", compiler.KindByte, 20)
			},
			wantAsm: []string{
				"\tlda a",
				"\tcmp b",
				"\tbne _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
		{
			name: "word var != word var (long jump)",
			line: "IF x <> y",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("x", "", compiler.KindWord, 0x1234)
				st.AddVar("y", "", compiler.KindWord, 0x5678)
			},
			wantAsm: []string{
				"\tlda x",
				"\tcmp y",
				"\tbne _L1",
				"\tlda x+1",
				"\tcmp y+1",
				"\tbne _L1",
				"\tjmp _I1",
				"_L1",
			},
		},
	})
}

// TestIfCommand_Errors feeds malformed IF lines to Interpret and checks that
// each one is rejected with an error whose text contains the expected
// fragment.
func TestIfCommand_Errors(t *testing.T) {
	tests := []struct {
		name      string
		line      string
		setupVars func(*compiler.SymbolTable)
		wantErr   string
	}{
		{
			name:      "wrong param count",
			line:      "IF a = b = c",
			setupVars: func(st *compiler.SymbolTable) {},
			wantErr:   "wrong number of parameters",
		},
		{
			name:      "unsupported operator",
			line:      "IF a > b",
			setupVars: func(st *compiler.SymbolTable) {},
			wantErr:   "unsupported operator",
		},
		{
			name: "unknown variable",
			line: "IF unknown = b",
			setupVars: func(st *compiler.SymbolTable) {
				st.AddVar("b", "", compiler.KindByte, 0)
			},
			wantErr: "not a valid variable or expression",
		},
		{
			name:      "invalid THEN",
			line:      "IF a = b NOT",
			setupVars: func(st *compiler.SymbolTable) {},
			wantErr:   "must be 'THEN'",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := compiler.NewCompilerContext(preproc.NewPragma())
			tt.setupVars(ctx.SymbolTable)

			cmd := &IfCommand{}
			line := preproc.Line{Text: tt.line, Kind: preproc.Source, PragmaSetIndex: 0}

			err := cmd.Interpret(line, ctx)
			if err == nil {
				t.Fatal("expected error, got nil")
			}
			if !strings.Contains(err.Error(), tt.wantErr) {
				t.Errorf("error = %q, want substring %q", err.Error(), tt.wantErr)
			}
		})
	}
}

// equalAsmIf reports whether a and b hold the same assembly lines in the same
// order. Only length and element values are compared, so a nil slice and an
// empty slice are considered equal.
func equalAsmIf(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}