|
2 | 2 | import logging |
3 | 3 | import re |
4 | 4 | from collections import defaultdict |
| 5 | +from rasa.cli.utils import bcolors |
5 | 6 |
|
6 | 7 | from rasa.core.trackers import DialogueStateTracker |
7 | 8 | from typing import Text, Any, Dict, Optional, List |
|
10 | 11 |
|
11 | 12 | logger = logging.getLogger(__name__) |
12 | 13 |
|
| 14 | +class InvalidNLGRequest(Exception): |
| 15 | + def __init__(self, message): |
| 16 | + self.message = message |
| 17 | + |
| 18 | + def __str__(self): |
| 19 | + return bcolors.FAIL + self.message + bcolors.ENDC |
13 | 20 |
|
14 | 21 | class TemplatedNaturalLanguageGenerator(NaturalLanguageGenerator): |
15 | 22 | """Natural language generator that generates messages based on templates. |
16 | 23 |
|
17 | 24 | The templates can use variables to customize the utterances based on the |
18 | 25 | state of the dialogue.""" |
19 | 26 |
|
20 | | - def __init__(self, templates: Dict[Text, List[Dict[Text, Any]]]) -> None: |
| 27 | + def __init__(self, templates: Dict[Text, Dict[Text, List[List[Dict[Text, Any]]]]]) -> None: |
21 | 28 | self.templates = templates |
22 | 29 |
|
23 | 30 | def _templates_for_utter_action(self, utter_action, output_channel): |
@@ -71,10 +78,28 @@ async def generate( |
71 | 78 | """Generate a response for the requested template.""" |
72 | 79 |
|
73 | 80 | filled_slots = tracker.current_slot_values() |
74 | | - return self.generate_from_slots( |
| 81 | + return self.generate_from_bf_template( |
75 | 82 | template_name, filled_slots, output_channel, **kwargs |
76 | 83 | ) |
77 | 84 |
|
| 85 | + def generate_from_bf_template( |
| 86 | + self, |
| 87 | + template_name: Text, |
| 88 | + filled_slots: Dict[Text, Any], |
| 89 | + output_channel: Text, |
| 90 | + **kwargs: Any |
| 91 | + ) -> Optional[Dict[Text, Any]]: |
| 92 | + |
| 93 | + language = kwargs.get("language", None) |
| 94 | + if not language: |
| 95 | + raise InvalidNLGRequest("Generator expected a language to return template") |
| 96 | + if template_name not in self.templates: |
| 97 | + return None |
| 98 | + |
| 99 | + return [self._fill_template_text( |
| 100 | + copy.deepcopy(template), filled_slots, **kwargs |
| 101 | + ) for template in self.templates[template_name][language]] |
| 102 | + |
78 | 103 | def generate_from_slots( |
79 | 104 | self, |
80 | 105 | template_name: Text, |
|
0 commit comments