Add fuzzy and text-classification light model #1

Commit 07c5184f13: Add new intent classifier, and add args parser to main

3 changed files with 115 additions and 22 deletions

main.py (65 changes)
@@ -1,9 +1,9 @@
 import logging
-import time
-import schedule
+import argparse
+# import schedule
 
 import config
 
 from src.rhasspy import rhasspy_mqtt as yoda_listener
 from src.rhasspy import Rhasspy
 from src.ratatouille import Ratatouille
@@ -11,30 +11,73 @@ from src.mpd import Mpd
 from src.hass import Hass
 from src.httpServer import get_server
 from src.fuzzy import fuzz_predict
-# from src.intent import AlexaIntent
+from src.intent import BertIntent
+from src.tools.simple_sentence_corrector import simple_sentence_corrector
+
+
+# --------- setup args -------------
+
+parser = argparse.ArgumentParser(
+    prog='Ratatouille',
+    description='Ratatouille le cerveau domotique !')
+
+parser.add_argument('mode')
+parser.add_argument('-i', '--ip', required=False)
+parser.add_argument('-p', '--port', required=False, type=int)
+
+args = parser.parse_args()
+
+if args.mode == "server":
+    if args.ip is None or args.port is None:
+        logging.error(" --ip or --port argument missing")
+        exit()
+
+
+# -------- setup logging ------------
 
 logging.basicConfig(
     level=10,
     format="%(asctime)s %(filename)s:%(lineno)s %(levelname)s %(message)s"
 )
 
-IP = "10.10.10.8"
-PORT = 5555
+logging.info("Loading ratatouilles modules")
+
+
+# ---------- other ---------------
 
 walle = Hass(config.hass_url, config.hass_token)
 yoda = None  # Rhasspy(config.rhasspy_url)
 mopidy = Mpd('10.10.10.8')
-ratatouille = Ratatouille(yoda, walle, mopidy, schedule)
+ratatouille = Ratatouille(yoda, walle, mopidy, None)
 # alexa = AlexaIntent() # we are not doing any request to the evil amazon but we are using one of its dataset
+bert = BertIntent()
 
 
 def answer(sentence):
     # return ratatouille.parse_alexa(alexa.predict(sentence))
-    return ratatouille.parse_fuzzy(fuzz_predict(sentence))
+    sentence_corrected = simple_sentence_corrector(sentence)
+    prediction = bert.predict(sentence_corrected)
+    return ratatouille.parse_fuzzy(prediction)
     # return "42"
 
 
-server = get_server(IP, PORT, answer)
+def run_server(ip, port):
+    server = get_server(ip, port, answer)
+    logging.info('Running server on '+ip+':'+str(port))
+    server.serve_forever()
+
 
-logging.info('Running server on '+IP+':'+str(PORT))
-server.serve_forever()
+def run_prompt():
+    question = "empty"
+    while question != "stop":
+        question = input("?")
+        print(answer(question))
+
+
+logging.info("Ratatouille is ready !")
+
+# run_server()
+if args.mode == "server":
+    run_server(str(args.ip), args.port)
+else:
+    run_prompt()
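For context, a minimal standalone sketch of how the argument parsing added to main.py above behaves; the invocations and printed values are illustrative, and 10.10.10.8/5555 are just the old hard-coded defaults that the parser now replaces.

# Standalone sketch of the argparse setup introduced in main.py.
# The parse_args() inputs below are illustrative command lines.
import argparse

parser = argparse.ArgumentParser(
    prog='Ratatouille',
    description='Ratatouille le cerveau domotique !')
parser.add_argument('mode')  # "server" or anything else for the interactive prompt
parser.add_argument('-i', '--ip', required=False)
parser.add_argument('-p', '--port', required=False, type=int)

# Equivalent of: python main.py server -i 10.10.10.8 -p 5555
args = parser.parse_args(['server', '-i', '10.10.10.8', '-p', '5555'])
print(args.mode, args.ip, args.port)   # -> server 10.10.10.8 5555

# Equivalent of: python main.py prompt  (no ip/port needed; run_prompt() is used)
args = parser.parse_args(['prompt'])
print(args.mode, args.ip, args.port)   # -> prompt None None
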
@@ -1,12 +1,14 @@
 import unittest
 from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
 from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
+from transformers import pipeline
 
 DOMOTIQUE_OBJ_ON = ['iot_wemo_on']
 DOMOTIQUE_OBJ_OFF = ['iot_wemo_off']
 
+
 class AlexaIntent():
-    
+
     def __init__(self):
         self.get_intents = self.init_intent_classification()
         self.get_entities = self.init_entities_classification()
@@ -15,21 +17,24 @@ class AlexaIntent():
         model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForSequenceClassification.from_pretrained(model_name)
-        classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
+        classifier = TextClassificationPipeline(
+            model=model, tokenizer=tokenizer)
         return classifier
 
     def init_entities_classification(self):
-        tokenizer = AutoTokenizer.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
-        model = AutoModelForTokenClassification.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
+        tokenizer = AutoTokenizer.from_pretrained(
+            'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
+        model = AutoModelForTokenClassification.from_pretrained(
+            'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
         predict = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
         return predict
 
-    def simple_sentence_corrector(self,sentence):    
-        sentence = sentence.replace('étant','éteins')
-        sentence = sentence.replace('dépeint','éteins')
+    def simple_sentence_corrector(self, sentence):
+        sentence = sentence.replace('étant', 'éteins')
+        sentence = sentence.replace('dépeint', 'éteins')
         return sentence
 
-    def intent_corrector(self,intents):
+    def intent_corrector(self, intents):
         for intent in intents:
             if intent['label'] in DOMOTIQUE_OBJ_ON:
                 intent['label'] = 'iot_hue_lighton'
@@ -37,7 +42,7 @@ class AlexaIntent():
                 intent['label'] = 'iot_hue_lightoff'
         return intents
 
-    def predict(self,sentence):
+    def predict(self, sentence):
         sentence = self.simple_sentence_corrector(sentence)
         return {
             'intents': self.intent_corrector(self.get_intents(sentence)),
@@ -45,6 +50,47 @@ class AlexaIntent():
         }
 
 
+class BertIntent():
+    def __init__(self):
+        self.classifier = pipeline("text-classification",
+                                   model="Tjiho/french-intents-classificaton")
+
+    def predict(self, sentence):
+        # sentence = self.simple_sentence_corrector(sentence)
+        classification = self.classifier(sentence)
+
+        if classification[0]["score"] < 0.7:  # score too low
+            return
+
+        label = classification[0]["label"]
+
+        if label == "HEURE":
+            return {'intentName': 'GetTime'}
+        elif label == "DATE":
+            return {'intentName': 'GetDate'}
+        elif label == "ETEINDRE_CUISINE":
+            return {'intentName': 'LightOff', 'intentArg': ['cuisine']}
+        elif label == "ETEINDRE_BUREAU":
+            return {'intentName': 'LightOff', 'intentArg': ['bureau']}
+        elif label == "ETEINDRE_SALON":
+            return {'intentName': 'LightOff', 'intentArg': ['salon']}
+        elif label == "ETEINDRE_CHAMBRE":
+            return {'intentName': 'LightOff', 'intentArg': ['chambre']}
+        elif label == "ALLUMER_CUISINE":
+            return {'intentName': 'LightOn', 'intentArg': ['cuisine']}
+        elif label == "ALLUMER_SALON":
+            return {'intentName': 'LightOn', 'intentArg': ['salon']}
+        elif label == "ALLUMER_BUREAU":
+            return {'intentName': 'LightOn', 'intentArg': ['bureau']}
+        elif label == "ALLUMER_CHAMBRE":
+            return {'intentName': 'LightOn', 'intentArg': ['chambre']}
+        elif label == "METEO":
+            return {'intentName': 'Meteo'}
+        elif label == "TEMPERATURE_EXTERIEUR":
+            return {'intentName': 'Temperature_ext'}
+        elif label == "TEMPERATURE_INTERIEUR":
+            return {'intentName': 'Temperature_int'}
+
 
 class TestAlexa(unittest.TestCase):
     @classmethod
@@ -60,11 +106,11 @@ class TestAlexa(unittest.TestCase):
         self.assertEqual(res['intents'][0]['label'], 'iot_hue_lighton')
         self.assertEqual(res['entities'][0]['word'], '▁cuisine')
 
-
     def test_bad_transcribe(self):
         res = self.alexa.predict("dépeint la cuisine")
         self.assertEqual(res['intents'][0]['label'], 'iot_hue_lightoff')
         self.assertEqual(res['entities'][0]['word'], '▁cuisine')
 
+
 if __name__ == '__main__':
     unittest.main()
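The class added above is the new BertIntent classifier that main.py now calls from answer(); the path of this second changed file is not shown in this view, but main.py imports it as from src.intent import BertIntent. A small usage sketch follows; the model name comes from the diff, while the input sentence and the score shown in the comments are illustrative assumptions.

# Usage sketch of what BertIntent does under the hood (illustrative input/output).
from transformers import pipeline

classifier = pipeline("text-classification",
                      model="Tjiho/french-intents-classificaton")

raw = classifier("allume la lumière du bureau")
# raw looks like [{'label': 'ALLUMER_BUREAU', 'score': 0.98}]  (score illustrative)

# BertIntent.predict() wraps this call: it drops predictions whose top score is
# below 0.7 (returning None) and otherwise maps the label to Ratatouille's
# intent format, e.g. {'intentName': 'LightOn', 'intentArg': ['bureau']}.
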
							
								
								
									
src/tools/simple_sentence_corrector.py (4 changes, new file)
@@ -0,0 +1,4 @@
+def simple_sentence_corrector(sentence):
+    sentence = sentence.replace('étant', 'éteins')
+    sentence = sentence.replace('dépeint', 'éteins')
+    return sentence
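A quick, assumed usage check of the new helper; the "dépeint" input mirrors the test_bad_transcribe case above, the "étant" one is illustrative.

# The helper rewrites common speech-to-text mishearings of "éteins"
# before the sentence reaches the intent classifier.
from src.tools.simple_sentence_corrector import simple_sentence_corrector

print(simple_sentence_corrector("dépeint la cuisine"))   # -> éteins la cuisine
print(simple_sentence_corrector("étant la lumière"))     # -> éteins la lumière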