yaze 0.3.2
Link to the Past ROM Editor
 
Loading...
Searching...
No Matches
conversation_test.cc
Go to the documentation of this file.
1#include "cli/handlers/commands.h"
2#include "app/rom.h"
3#include "app/core/project.h"
5
6#include "absl/flags/declare.h"
7#include "absl/flags/flag.h"
8#include <fstream>
9#include <iostream>
10#include <string>
11#include <vector>
12
13#include "absl/status/status.h"
14#include "absl/strings/str_cat.h"
17#include "nlohmann/json.hpp"
18
19ABSL_DECLARE_FLAG(std::string, rom);
20ABSL_DECLARE_FLAG(bool, mock_rom);
21
22namespace yaze {
23namespace cli {
24namespace agent {
25
26namespace {
27
28absl::Status LoadRomForAgent(Rom& rom) {
29 if (rom.is_loaded()) {
30 return ::absl::OkStatus();
31 }
32
33 // Check if mock ROM mode is enabled
34 bool use_mock = ::absl::GetFlag(FLAGS_mock_rom);
35 if (use_mock) {
36 // Initialize mock ROM with embedded labels
37 auto status = InitializeMockRom(rom);
38 if (!status.ok()) {
39 return status;
40 }
41 std::cout << "โœ… Mock ROM initialized with embedded Zelda3 labels\n";
42 return ::absl::OkStatus();
43 }
44
45 // Otherwise load from file
46 std::string rom_path = ::absl::GetFlag(FLAGS_rom);
47 if (rom_path.empty()) {
48 return ::absl::InvalidArgumentError(
49 "No ROM loaded. Pass --rom=<path> or use --mock-rom for testing.");
50 }
51
52 auto status = rom.LoadFromFile(rom_path);
53 if (!status.ok()) {
54 return ::absl::FailedPreconditionError(
55 ::absl::StrCat("Failed to load ROM from '", rom_path,
56 "': ", status.message()));
57 }
58
59 return ::absl::OkStatus();
60}
61
// One scripted conversation to run against the agent, plus the loose
// expectations that are checked against each response.
struct ConversationTestCase {
  // Identifier printed in the test header and in pass/fail lines.
  std::string name;
  // Human-readable summary printed under the name.
  std::string description;
  // Messages sent to the agent, in order, within a single conversation.
  std::vector<std::string> user_prompts;
  // Keywords to look for in responses (plain substring search; a miss only
  // produces a warning).
  std::vector<std::string> expected_keywords;
  // When true, a response without table data produces a warning.
  bool expect_tool_calls = false;
  // When true, a response mentioning none of the known command words
  // produces a warning.
  bool expect_commands = false;
};
70
71std::vector<ConversationTestCase> GetDefaultTestCases() {
72 return {
73 {
74 .name = "embedded_labels_room_query",
75 .description = "Ask about room names using embedded labels",
76 .user_prompts = {"What is the name of room 5?"},
77 .expected_keywords = {"room", "Tower of Hera", "Moldorm"},
78 .expect_tool_calls = false,
79 .expect_commands = false,
80 },
81 {
82 .name = "embedded_labels_sprite_query",
83 .description = "Ask about sprite names using embedded labels",
84 .user_prompts = {"What is sprite 9?"},
85 .expected_keywords = {"sprite", "Moldorm", "Boss"},
86 .expect_tool_calls = false,
87 .expect_commands = false,
88 },
89 {
90 .name = "embedded_labels_entrance_query",
91 .description = "Ask about entrance names using embedded labels",
92 .user_prompts = {"What is entrance 0?"},
93 .expected_keywords = {"entrance", "Link", "House"},
94 .expect_tool_calls = false,
95 .expect_commands = false,
96 },
97 {
98 .name = "simple_question",
99 .description = "Ask about dungeons in the ROM",
100 .user_prompts = {"What dungeons are in this ROM?"},
101 .expected_keywords = {"dungeon", "palace", "castle"},
102 .expect_tool_calls = true,
103 .expect_commands = false,
104 },
105 {
106 .name = "list_all_rooms",
107 .description = "List all room names with embedded labels",
108 .user_prompts = {"List the first 10 dungeon rooms"},
109 .expected_keywords = {"room", "Ganon", "Hyrule", "Palace"},
110 .expect_tool_calls = true,
111 .expect_commands = false,
112 },
113 {
114 .name = "overworld_tile_search",
115 .description = "Find specific tiles in overworld",
116 .user_prompts = {"Find all trees on the overworld"},
117 .expected_keywords = {"tree", "tile", "map"},
118 .expect_tool_calls = true,
119 .expect_commands = false,
120 },
121 {
122 .name = "multi_step_query",
123 .description = "Ask multiple questions in sequence",
124 .user_prompts = {
125 "What is the name of room 0?",
126 "What sprites are defined in the game?",
127 },
128 .expected_keywords = {"Ganon", "sprite", "room"},
129 .expect_tool_calls = true,
130 .expect_commands = false,
131 },
132 {
133 .name = "map_description",
134 .description = "Get information about a specific map",
135 .user_prompts = {"Describe overworld map 0"},
136 .expected_keywords = {"map", "light world", "tile"},
137 .expect_tool_calls = true,
138 .expect_commands = false,
139 },
140 };
141}
142
143void PrintTestHeader(const ConversationTestCase& test_case) {
144 std::cout << "\n===========================================\n";
145 std::cout << "Test: " << test_case.name << "\n";
146 std::cout << "Description: " << test_case.description << "\n";
147 std::cout << "===========================================\n\n";
148}
149
// Echoes the prompt that is about to be sent to the agent, followed by a
// blank line.
void PrintUserPrompt(const std::string& prompt) {
  std::cout << "๐Ÿ‘ค User: ";
  std::cout << prompt;
  std::cout << "\n\n";
}
153
154void PrintAgentResponse(const ChatMessage& response, bool verbose) {
155 std::cout << "๐Ÿค– Agent: " << response.message << "\n\n";
156
157 if (verbose && response.json_pretty.has_value()) {
158 std::cout << "๐Ÿงพ JSON Output:\n" << *response.json_pretty << "\n\n";
159 }
160
161 if (response.table_data.has_value()) {
162 std::cout << "๐Ÿ“Š Table Output:\n";
163 const auto& table = response.table_data.value();
164
165 // Print headers
166 std::cout << " ";
167 for (size_t i = 0; i < table.headers.size(); ++i) {
168 std::cout << table.headers[i];
169 if (i < table.headers.size() - 1) {
170 std::cout << " | ";
171 }
172 }
173 std::cout << "\n ";
174 for (size_t i = 0; i < table.headers.size(); ++i) {
175 std::cout << std::string(table.headers[i].length(), '-');
176 if (i < table.headers.size() - 1) {
177 std::cout << " | ";
178 }
179 }
180 std::cout << "\n";
181
182 // Print rows (limit to 10 for readability)
183 const size_t max_rows = std::min<size_t>(10, table.rows.size());
184 for (size_t i = 0; i < max_rows; ++i) {
185 std::cout << " ";
186 for (size_t j = 0; j < table.rows[i].size(); ++j) {
187 std::cout << table.rows[i][j];
188 if (j < table.rows[i].size() - 1) {
189 std::cout << " | ";
190 }
191 }
192 std::cout << "\n";
193 }
194
195 if (!verbose && table.rows.size() > max_rows) {
196 std::cout << " ... (" << (table.rows.size() - max_rows)
197 << " more rows)\n";
198 }
199
200 if (verbose && table.rows.size() > max_rows) {
201 for (size_t i = max_rows; i < table.rows.size(); ++i) {
202 std::cout << " ";
203 for (size_t j = 0; j < table.rows[i].size(); ++j) {
204 std::cout << table.rows[i][j];
205 if (j < table.rows[i].size() - 1) {
206 std::cout << " | ";
207 }
208 }
209 std::cout << "\n";
210 }
211 }
212 std::cout << "\n";
213 }
214}
215
216bool ValidateResponse(const ChatMessage& response,
217 const ConversationTestCase& test_case) {
218 bool passed = true;
219
220 // Check for expected keywords
221 for (const auto& keyword : test_case.expected_keywords) {
222 if (response.message.find(keyword) == std::string::npos) {
223 std::cout << "โš ๏ธ Warning: Expected keyword '" << keyword
224 << "' not found in response\n";
225 // Don't fail test, just warn
226 }
227 }
228
229 // Check for tool calls (if we have table data, tools were likely called)
230 if (test_case.expect_tool_calls && !response.table_data.has_value()) {
231 std::cout << "โš ๏ธ Warning: Expected tool calls but no table data found\n";
232 }
233
234 // Check for commands
235 if (test_case.expect_commands) {
236 bool has_commands = response.message.find("overworld") != std::string::npos ||
237 response.message.find("dungeon") != std::string::npos ||
238 response.message.find("set-tile") != std::string::npos;
239 if (!has_commands) {
240 std::cout << "โš ๏ธ Warning: Expected commands but none found\n";
241 }
242 }
243
244 return passed;
245}
246
247absl::Status RunTestCase(const ConversationTestCase& test_case,
249 bool verbose) {
250 PrintTestHeader(test_case);
251
252 bool all_passed = true;
253
254 service.ResetConversation();
255
256 for (const auto& prompt : test_case.user_prompts) {
257 PrintUserPrompt(prompt);
258
259 auto response_or = service.SendMessage(prompt);
260 if (!response_or.ok()) {
261 std::cout << "โŒ FAILED: " << response_or.status().message() << "\n\n";
262 all_passed = false;
263 continue;
264 }
265
266 const auto& response = response_or.value();
267 PrintAgentResponse(response, verbose);
268
269 if (!ValidateResponse(response, test_case)) {
270 all_passed = false;
271 }
272 }
273
274 if (verbose) {
275 const auto& history = service.GetHistory();
276 std::cout << "๐Ÿ—‚ Conversation Summary (" << history.size()
277 << " message" << (history.size() == 1 ? "" : "s") << ")\n";
278 for (const auto& message : history) {
279 const char* sender =
280 message.sender == ChatMessage::Sender::kUser ? "User" : "Agent";
281 std::cout << " [" << sender << "] " << message.message << "\n";
282 }
283 std::cout << "\n";
284 }
285
286 if (all_passed) {
287 std::cout << "โœ… Test PASSED: " << test_case.name << "\n";
288 return absl::OkStatus();
289 }
290
291 std::cout << "โš ๏ธ Test completed with warnings: " << test_case.name << "\n";
292 return absl::InternalError(
293 absl::StrCat("Conversation test failed validation: ", test_case.name));
294}
295
296absl::Status LoadTestCasesFromFile(const std::string& file_path,
297 std::vector<ConversationTestCase>* test_cases) {
298 std::ifstream file(file_path);
299 if (!file.is_open()) {
300 return absl::NotFoundError(
301 absl::StrCat("Could not open test file: ", file_path));
302 }
303
304 nlohmann::json test_json;
305 try {
306 file >> test_json;
307 } catch (const nlohmann::json::parse_error& e) {
308 return absl::InvalidArgumentError(
309 absl::StrCat("Failed to parse test file: ", e.what()));
310 }
311
312 if (!test_json.is_array()) {
313 return absl::InvalidArgumentError(
314 "Test file must contain a JSON array of test cases");
315 }
316
317 for (const auto& test_obj : test_json) {
318 ConversationTestCase test_case;
319 test_case.name = test_obj.value("name", "unnamed_test");
320 test_case.description = test_obj.value("description", "");
321
322 if (test_obj.contains("prompts") && test_obj["prompts"].is_array()) {
323 for (const auto& prompt : test_obj["prompts"]) {
324 if (prompt.is_string()) {
325 test_case.user_prompts.push_back(prompt.get<std::string>());
326 }
327 }
328 }
329
330 if (test_obj.contains("expected_keywords") &&
331 test_obj["expected_keywords"].is_array()) {
332 for (const auto& keyword : test_obj["expected_keywords"]) {
333 if (keyword.is_string()) {
334 test_case.expected_keywords.push_back(keyword.get<std::string>());
335 }
336 }
337 }
338
339 test_case.expect_tool_calls = test_obj.value("expect_tool_calls", false);
340 test_case.expect_commands = test_obj.value("expect_commands", false);
341
342 test_cases->push_back(test_case);
343 }
344
345 return absl::OkStatus();
346}
347
348} // namespace
349
351 const std::vector<std::string>& arg_vec) {
352 std::string test_file;
353 bool use_defaults = true;
354 bool verbose = false;
355
356 for (size_t i = 0; i < arg_vec.size(); ++i) {
357 const std::string& arg = arg_vec[i];
358 if (arg == "--file" && i + 1 < arg_vec.size()) {
359 test_file = arg_vec[i + 1];
360 use_defaults = false;
361 ++i;
362 } else if (arg == "--verbose") {
363 verbose = true;
364 }
365 }
366
367 std::cout << "๐Ÿ” Debug: Starting test-conversation handler...\n";
368
369 // Load ROM context
370 Rom rom;
371 std::cout << "๐Ÿ” Debug: Loading ROM...\n";
372 auto load_status = LoadRomForAgent(rom);
373 if (!load_status.ok()) {
374 std::cerr << "โŒ Error loading ROM: " << load_status.message() << "\n";
375 return load_status;
376 }
377
378 std::cout << "โœ… ROM loaded: " << rom.title() << "\n";
379
380 // Load embedded labels for natural language queries
381 std::cout << "๐Ÿ” Debug: Initializing embedded labels...\n";
382 core::YazeProject project;
383 auto labels_status = project.InitializeEmbeddedLabels();
384 if (!labels_status.ok()) {
385 std::cerr << "โš ๏ธ Warning: Could not initialize embedded labels: "
386 << labels_status.message() << "\n";
387 } else {
388 std::cout << "โœ… Embedded labels initialized successfully\n";
389 }
390
391 // Associate labels with ROM if it has a resource label manager
392 std::cout << "๐Ÿ” Debug: Checking resource label manager...\n";
393 if (rom.resource_label() && project.use_embedded_labels) {
394 std::cout << "๐Ÿ” Debug: Associating labels with ROM...\n";
395 rom.resource_label()->labels_ = project.resource_labels;
396 rom.resource_label()->labels_loaded_ = true;
397 std::cout << "โœ… Embedded labels loaded and associated with ROM\n";
398 } else {
399 std::cout << "โš ๏ธ ROM has no resource label manager\n";
400 }
401
402 // Create conversational agent service
403 std::cout << "๐Ÿ” Debug: Creating conversational agent service...\n";
404 std::cout << "๐Ÿ” Debug: About to construct service object...\n";
405
407 std::cout << "โœ… Service object created\n";
408
409 std::cout << "๐Ÿ” Debug: Setting ROM context...\n";
410 service.SetRomContext(&rom);
411 std::cout << "โœ… Service initialized\n";
412
413 // Load test cases
414 std::vector<ConversationTestCase> test_cases;
415 if (use_defaults) {
416 test_cases = GetDefaultTestCases();
417 std::cout << "Using default test cases (" << test_cases.size() << " tests)\n";
418 } else {
419 auto status = LoadTestCasesFromFile(test_file, &test_cases);
420 if (!status.ok()) {
421 return status;
422 }
423 std::cout << "Loaded " << test_cases.size() << " test cases from "
424 << test_file << "\n";
425 }
426
427 if (test_cases.empty()) {
428 return absl::InvalidArgumentError("No test cases to run");
429 }
430
431 // Run all test cases
432 int passed = 0;
433 int failed = 0;
434
435 for (const auto& test_case : test_cases) {
436 auto status = RunTestCase(test_case, service, verbose);
437 if (status.ok()) {
438 ++passed;
439 } else {
440 ++failed;
441 std::cerr << "Test case '" << test_case.name << "' failed: "
442 << status.message() << "\n";
443 }
444 }
445
446 // Print summary
447 std::cout << "\n===========================================\n";
448 std::cout << "Test Summary\n";
449 std::cout << "===========================================\n";
450 std::cout << "Total tests: " << test_cases.size() << "\n";
451 std::cout << "Passed: " << passed << "\n";
452 std::cout << "Failed: " << failed << "\n";
453
454 if (failed == 0) {
455 std::cout << "\nโœ… All tests passed!\n";
456 } else {
457 std::cout << "\nโš ๏ธ Some tests failed\n";
458 }
459
460 if (failed == 0) {
461 return absl::OkStatus();
462 }
463
464 return absl::InternalError(
465 absl::StrCat(failed, " conversation test(s) reported failures"));
466}
467
468} // namespace agent
469} // namespace cli
470} // namespace yaze
The Rom class is used to load, save, and modify Rom data.
Definition rom.h:71
absl::Status LoadFromFile(const std::string &filename, bool z3_load=true)
Definition rom.cc:289
core::ResourceLabelManager * resource_label()
Definition rom.h:220
bool is_loaded() const
Definition rom.h:197
auto title() const
Definition rom.h:201
absl::StatusOr< ChatMessage > SendMessage(const std::string &message)
const std::vector< ChatMessage > & GetHistory() const
ABSL_DECLARE_FLAG(std::string, rom)
bool ValidateResponse(const ChatMessage &response, const ConversationTestCase &test_case)
void PrintTestHeader(const ConversationTestCase &test_case)
absl::Status LoadTestCasesFromFile(const std::string &file_path, std::vector< ConversationTestCase > *test_cases)
void PrintAgentResponse(const ChatMessage &response, bool verbose)
absl::Status RunTestCase(const ConversationTestCase &test_case, ConversationalAgentService &service, bool verbose)
absl::Status HandleTestConversationCommand(const std::vector< std::string > &args)
absl::Status InitializeMockRom(Rom &rom)
Initialize a mock ROM for testing without requiring an actual ROM file.
Definition mock_rom.cc:16
Main namespace for the application.
std::optional< std::string > json_pretty
std::unordered_map< std::string, std::unordered_map< std::string, std::string > > labels_
Definition project.h:235
Modern project structure with comprehensive settings consolidation.
Definition project.h:78
std::unordered_map< std::string, std::unordered_map< std::string, std::string > > resource_labels
Definition project.h:100
absl::Status InitializeEmbeddedLabels()
Definition project.cc:887