yaze 0.3.2
Link to the Past ROM Editor
 
Loading...
Searching...
No Matches
conversation_test.cc
Go to the documentation of this file.
1#include "app/rom.h"
2#include "core/project.h"
4
5#include "absl/flags/declare.h"
6#include "absl/flags/flag.h"
7#include <fstream>
8#include <iostream>
9#include <string>
10#include <vector>
11
12#include "absl/status/status.h"
13#include "absl/strings/str_cat.h"
16#include "nlohmann/json.hpp"
17
18ABSL_DECLARE_FLAG(std::string, rom);
19ABSL_DECLARE_FLAG(bool, mock_rom);
20
21namespace yaze {
22namespace cli {
23namespace agent {
24
25namespace {
26
27absl::Status LoadRomForAgent(Rom& rom) {
28 if (rom.is_loaded()) {
29 return ::absl::OkStatus();
30 }
31
32 // Check if mock ROM mode is enabled
33 bool use_mock = ::absl::GetFlag(FLAGS_mock_rom);
34 if (use_mock) {
35 // Initialize mock ROM with embedded labels
36 auto status = InitializeMockRom(rom);
37 if (!status.ok()) {
38 return status;
39 }
40 std::cout << "โœ… Mock ROM initialized with embedded Zelda3 labels\n";
41 return ::absl::OkStatus();
42 }
43
44 // Otherwise load from file
45 std::string rom_path = ::absl::GetFlag(FLAGS_rom);
46 if (rom_path.empty()) {
47 return ::absl::InvalidArgumentError(
48 "No ROM loaded. Pass --rom=<path> or use --mock-rom for testing.");
49 }
50
51 auto status = rom.LoadFromFile(rom_path);
52 if (!status.ok()) {
53 return ::absl::FailedPreconditionError(
54 ::absl::StrCat("Failed to load ROM from '", rom_path,
55 "': ", status.message()));
56 }
57
58 return ::absl::OkStatus();
59}
60
// One scripted conversation plus the signals expected in the agent's replies.
struct ConversationTestCase {
  // Identifier printed in the test header and in pass/fail messages.
  std::string name;
  // Human-readable summary printed in the test header.
  std::string description;
  // Prompts sent to the agent, in order, within a single conversation.
  std::vector<std::string> user_prompts;
  // Keywords to look for in responses; a missing keyword only produces a
  // warning.
  std::vector<std::string> expected_keywords;
  // When true, a response without table data produces a warning.
  bool expect_tool_calls = false;
  // When true, a response mentioning none of the known command words produces
  // a warning.
  bool expect_commands = false;
};
69
70std::vector<ConversationTestCase> GetDefaultTestCases() {
71 return {
72 {
73 .name = "embedded_labels_room_query",
74 .description = "Ask about room names using embedded labels",
75 .user_prompts = {"What is the name of room 5?"},
76 .expected_keywords = {"room", "Tower of Hera", "Moldorm"},
77 .expect_tool_calls = false,
78 .expect_commands = false,
79 },
80 {
81 .name = "embedded_labels_sprite_query",
82 .description = "Ask about sprite names using embedded labels",
83 .user_prompts = {"What is sprite 9?"},
84 .expected_keywords = {"sprite", "Moldorm", "Boss"},
85 .expect_tool_calls = false,
86 .expect_commands = false,
87 },
88 {
89 .name = "embedded_labels_entrance_query",
90 .description = "Ask about entrance names using embedded labels",
91 .user_prompts = {"What is entrance 0?"},
92 .expected_keywords = {"entrance", "Link", "House"},
93 .expect_tool_calls = false,
94 .expect_commands = false,
95 },
96 {
97 .name = "simple_question",
98 .description = "Ask about dungeons in the ROM",
99 .user_prompts = {"What dungeons are in this ROM?"},
100 .expected_keywords = {"dungeon", "palace", "castle"},
101 .expect_tool_calls = true,
102 .expect_commands = false,
103 },
104 {
105 .name = "list_all_rooms",
106 .description = "List all room names with embedded labels",
107 .user_prompts = {"List the first 10 dungeon rooms"},
108 .expected_keywords = {"room", "Ganon", "Hyrule", "Palace"},
109 .expect_tool_calls = true,
110 .expect_commands = false,
111 },
112 {
113 .name = "overworld_tile_search",
114 .description = "Find specific tiles in overworld",
115 .user_prompts = {"Find all trees on the overworld"},
116 .expected_keywords = {"tree", "tile", "map"},
117 .expect_tool_calls = true,
118 .expect_commands = false,
119 },
120 {
121 .name = "multi_step_query",
122 .description = "Ask multiple questions in sequence",
123 .user_prompts = {
124 "What is the name of room 0?",
125 "What sprites are defined in the game?",
126 },
127 .expected_keywords = {"Ganon", "sprite", "room"},
128 .expect_tool_calls = true,
129 .expect_commands = false,
130 },
131 {
132 .name = "map_description",
133 .description = "Get information about a specific map",
134 .user_prompts = {"Describe overworld map 0"},
135 .expected_keywords = {"map", "light world", "tile"},
136 .expect_tool_calls = true,
137 .expect_commands = false,
138 },
139 };
140}
141
142void PrintTestHeader(const ConversationTestCase& test_case) {
143 std::cout << "\n===========================================\n";
144 std::cout << "Test: " << test_case.name << "\n";
145 std::cout << "Description: " << test_case.description << "\n";
146 std::cout << "===========================================\n\n";
147}
148
// Echoes a user prompt to stdout, prefixed with the user marker and followed
// by a blank line.
void PrintUserPrompt(const std::string& prompt) {
  std::cout << "👤 User: " << prompt << "\n\n";
}
152
153void PrintAgentResponse(const ChatMessage& response, bool verbose) {
154 std::cout << "๐Ÿค– Agent: " << response.message << "\n\n";
155
156 if (verbose && response.json_pretty.has_value()) {
157 std::cout << "๐Ÿงพ JSON Output:\n" << *response.json_pretty << "\n\n";
158 }
159
160 if (response.table_data.has_value()) {
161 std::cout << "๐Ÿ“Š Table Output:\n";
162 const auto& table = response.table_data.value();
163
164 // Print headers
165 std::cout << " ";
166 for (size_t i = 0; i < table.headers.size(); ++i) {
167 std::cout << table.headers[i];
168 if (i < table.headers.size() - 1) {
169 std::cout << " | ";
170 }
171 }
172 std::cout << "\n ";
173 for (size_t i = 0; i < table.headers.size(); ++i) {
174 std::cout << std::string(table.headers[i].length(), '-');
175 if (i < table.headers.size() - 1) {
176 std::cout << " | ";
177 }
178 }
179 std::cout << "\n";
180
181 // Print rows (limit to 10 for readability)
182 const size_t max_rows = std::min<size_t>(10, table.rows.size());
183 for (size_t i = 0; i < max_rows; ++i) {
184 std::cout << " ";
185 for (size_t j = 0; j < table.rows[i].size(); ++j) {
186 std::cout << table.rows[i][j];
187 if (j < table.rows[i].size() - 1) {
188 std::cout << " | ";
189 }
190 }
191 std::cout << "\n";
192 }
193
194 if (!verbose && table.rows.size() > max_rows) {
195 std::cout << " ... (" << (table.rows.size() - max_rows)
196 << " more rows)\n";
197 }
198
199 if (verbose && table.rows.size() > max_rows) {
200 for (size_t i = max_rows; i < table.rows.size(); ++i) {
201 std::cout << " ";
202 for (size_t j = 0; j < table.rows[i].size(); ++j) {
203 std::cout << table.rows[i][j];
204 if (j < table.rows[i].size() - 1) {
205 std::cout << " | ";
206 }
207 }
208 std::cout << "\n";
209 }
210 }
211 std::cout << "\n";
212 }
213}
214
215bool ValidateResponse(const ChatMessage& response,
216 const ConversationTestCase& test_case) {
217 bool passed = true;
218
219 // Check for expected keywords
220 for (const auto& keyword : test_case.expected_keywords) {
221 if (response.message.find(keyword) == std::string::npos) {
222 std::cout << "โš ๏ธ Warning: Expected keyword '" << keyword
223 << "' not found in response\n";
224 // Don't fail test, just warn
225 }
226 }
227
228 // Check for tool calls (if we have table data, tools were likely called)
229 if (test_case.expect_tool_calls && !response.table_data.has_value()) {
230 std::cout << "โš ๏ธ Warning: Expected tool calls but no table data found\n";
231 }
232
233 // Check for commands
234 if (test_case.expect_commands) {
235 bool has_commands = response.message.find("overworld") != std::string::npos ||
236 response.message.find("dungeon") != std::string::npos ||
237 response.message.find("set-tile") != std::string::npos;
238 if (!has_commands) {
239 std::cout << "โš ๏ธ Warning: Expected commands but none found\n";
240 }
241 }
242
243 return passed;
244}
245
246absl::Status RunTestCase(const ConversationTestCase& test_case,
248 bool verbose) {
249 PrintTestHeader(test_case);
250
251 bool all_passed = true;
252
253 service.ResetConversation();
254
255 for (const auto& prompt : test_case.user_prompts) {
256 PrintUserPrompt(prompt);
257
258 auto response_or = service.SendMessage(prompt);
259 if (!response_or.ok()) {
260 std::cout << "โŒ FAILED: " << response_or.status().message() << "\n\n";
261 all_passed = false;
262 continue;
263 }
264
265 const auto& response = response_or.value();
266 PrintAgentResponse(response, verbose);
267
268 if (!ValidateResponse(response, test_case)) {
269 all_passed = false;
270 }
271 }
272
273 if (verbose) {
274 const auto& history = service.GetHistory();
275 std::cout << "๐Ÿ—‚ Conversation Summary (" << history.size()
276 << " message" << (history.size() == 1 ? "" : "s") << ")\n";
277 for (const auto& message : history) {
278 const char* sender =
279 message.sender == ChatMessage::Sender::kUser ? "User" : "Agent";
280 std::cout << " [" << sender << "] " << message.message << "\n";
281 }
282 std::cout << "\n";
283 }
284
285 if (all_passed) {
286 std::cout << "โœ… Test PASSED: " << test_case.name << "\n";
287 return absl::OkStatus();
288 }
289
290 std::cout << "โš ๏ธ Test completed with warnings: " << test_case.name << "\n";
291 return absl::InternalError(
292 absl::StrCat("Conversation test failed validation: ", test_case.name));
293}
294
295absl::Status LoadTestCasesFromFile(const std::string& file_path,
296 std::vector<ConversationTestCase>* test_cases) {
297 std::ifstream file(file_path);
298 if (!file.is_open()) {
299 return absl::NotFoundError(
300 absl::StrCat("Could not open test file: ", file_path));
301 }
302
303 nlohmann::json test_json;
304 try {
305 file >> test_json;
306 } catch (const nlohmann::json::parse_error& e) {
307 return absl::InvalidArgumentError(
308 absl::StrCat("Failed to parse test file: ", e.what()));
309 }
310
311 if (!test_json.is_array()) {
312 return absl::InvalidArgumentError(
313 "Test file must contain a JSON array of test cases");
314 }
315
316 for (const auto& test_obj : test_json) {
317 ConversationTestCase test_case;
318 test_case.name = test_obj.value("name", "unnamed_test");
319 test_case.description = test_obj.value("description", "");
320
321 if (test_obj.contains("prompts") && test_obj["prompts"].is_array()) {
322 for (const auto& prompt : test_obj["prompts"]) {
323 if (prompt.is_string()) {
324 test_case.user_prompts.push_back(prompt.get<std::string>());
325 }
326 }
327 }
328
329 if (test_obj.contains("expected_keywords") &&
330 test_obj["expected_keywords"].is_array()) {
331 for (const auto& keyword : test_obj["expected_keywords"]) {
332 if (keyword.is_string()) {
333 test_case.expected_keywords.push_back(keyword.get<std::string>());
334 }
335 }
336 }
337
338 test_case.expect_tool_calls = test_obj.value("expect_tool_calls", false);
339 test_case.expect_commands = test_obj.value("expect_commands", false);
340
341 test_cases->push_back(test_case);
342 }
343
344 return absl::OkStatus();
345}
346
347} // namespace
348
350 const std::vector<std::string>& arg_vec) {
351 std::string test_file;
352 bool use_defaults = true;
353 bool verbose = false;
354
355 for (size_t i = 0; i < arg_vec.size(); ++i) {
356 const std::string& arg = arg_vec[i];
357 if (arg == "--file" && i + 1 < arg_vec.size()) {
358 test_file = arg_vec[i + 1];
359 use_defaults = false;
360 ++i;
361 } else if (arg == "--verbose") {
362 verbose = true;
363 }
364 }
365
366 std::cout << "๐Ÿ” Debug: Starting test-conversation handler...\n";
367
368 // Load ROM context
369 Rom rom;
370 std::cout << "๐Ÿ” Debug: Loading ROM...\n";
371 auto load_status = LoadRomForAgent(rom);
372 if (!load_status.ok()) {
373 std::cerr << "โŒ Error loading ROM: " << load_status.message() << "\n";
374 return load_status;
375 }
376
377 std::cout << "โœ… ROM loaded: " << rom.title() << "\n";
378
379 // Load embedded labels for natural language queries
380 std::cout << "๐Ÿ” Debug: Initializing embedded labels...\n";
381 project::YazeProject project;
382 auto labels_status = project.InitializeEmbeddedLabels();
383 if (!labels_status.ok()) {
384 std::cerr << "โš ๏ธ Warning: Could not initialize embedded labels: "
385 << labels_status.message() << "\n";
386 } else {
387 std::cout << "โœ… Embedded labels initialized successfully\n";
388 }
389
390 // Associate labels with ROM if it has a resource label manager
391 std::cout << "๐Ÿ” Debug: Checking resource label manager...\n";
392 if (rom.resource_label() && project.use_embedded_labels) {
393 std::cout << "๐Ÿ” Debug: Associating labels with ROM...\n";
394 rom.resource_label()->labels_ = project.resource_labels;
395 rom.resource_label()->labels_loaded_ = true;
396 std::cout << "โœ… Embedded labels loaded and associated with ROM\n";
397 } else {
398 std::cout << "โš ๏ธ ROM has no resource label manager\n";
399 }
400
401 // Create conversational agent service
402 std::cout << "๐Ÿ” Debug: Creating conversational agent service...\n";
403 std::cout << "๐Ÿ” Debug: About to construct service object...\n";
404
406 std::cout << "โœ… Service object created\n";
407
408 std::cout << "๐Ÿ” Debug: Setting ROM context...\n";
409 service.SetRomContext(&rom);
410 std::cout << "โœ… Service initialized\n";
411
412 // Load test cases
413 std::vector<ConversationTestCase> test_cases;
414 if (use_defaults) {
415 test_cases = GetDefaultTestCases();
416 std::cout << "Using default test cases (" << test_cases.size() << " tests)\n";
417 } else {
418 auto status = LoadTestCasesFromFile(test_file, &test_cases);
419 if (!status.ok()) {
420 return status;
421 }
422 std::cout << "Loaded " << test_cases.size() << " test cases from "
423 << test_file << "\n";
424 }
425
426 if (test_cases.empty()) {
427 return absl::InvalidArgumentError("No test cases to run");
428 }
429
430 // Run all test cases
431 int passed = 0;
432 int failed = 0;
433
434 for (const auto& test_case : test_cases) {
435 auto status = RunTestCase(test_case, service, verbose);
436 if (status.ok()) {
437 ++passed;
438 } else {
439 ++failed;
440 std::cerr << "Test case '" << test_case.name << "' failed: "
441 << status.message() << "\n";
442 }
443 }
444
445 // Print summary
446 std::cout << "\n===========================================\n";
447 std::cout << "Test Summary\n";
448 std::cout << "===========================================\n";
449 std::cout << "Total tests: " << test_cases.size() << "\n";
450 std::cout << "Passed: " << passed << "\n";
451 std::cout << "Failed: " << failed << "\n";
452
453 if (failed == 0) {
454 std::cout << "\nโœ… All tests passed!\n";
455 } else {
456 std::cout << "\nโš ๏ธ Some tests failed\n";
457 }
458
459 if (failed == 0) {
460 return absl::OkStatus();
461 }
462
463 return absl::InternalError(
464 absl::StrCat(failed, " conversation test(s) reported failures"));
465}
466
467} // namespace agent
468} // namespace cli
469} // namespace yaze
The Rom class is used to load, save, and modify Rom data.
Definition rom.h:74
project::ResourceLabelManager * resource_label()
Definition rom.h:223
absl::Status LoadFromFile(const std::string &filename, bool z3_load=true)
Definition rom.cc:292
bool is_loaded() const
Definition rom.h:200
auto title() const
Definition rom.h:204
absl::StatusOr< ChatMessage > SendMessage(const std::string &message)
const std::vector< ChatMessage > & GetHistory() const
ABSL_DECLARE_FLAG(std::string, rom)
bool ValidateResponse(const ChatMessage &response, const ConversationTestCase &test_case)
void PrintTestHeader(const ConversationTestCase &test_case)
absl::Status LoadTestCasesFromFile(const std::string &file_path, std::vector< ConversationTestCase > *test_cases)
void PrintAgentResponse(const ChatMessage &response, bool verbose)
absl::Status RunTestCase(const ConversationTestCase &test_case, ConversationalAgentService &service, bool verbose)
absl::Status HandleTestConversationCommand(const std::vector< std::string > &args)
absl::Status InitializeMockRom(Rom &rom)
Initialize a mock ROM for testing without requiring an actual ROM file.
Definition mock_rom.cc:16
Main namespace for the application.
Definition controller.cc:20
std::optional< std::string > json_pretty
std::unordered_map< std::string, std::unordered_map< std::string, std::string > > labels_
Definition project.h:235
Modern project structure with comprehensive settings consolidation.
Definition project.h:78
absl::Status InitializeEmbeddedLabels()
Definition project.cc:887
std::unordered_map< std::string, std::unordered_map< std::string, std::string > > resource_labels
Definition project.h:100