Add fuzzy and text-classification light model #1
					 3 changed files with 115 additions and 22 deletions
				
			
		Add new intent classifier, and add args parser to main
				commit
				
					
					
						07c5184f13
					
				
			
		
							
								
								
									
										65
									
								
								main.py
									
										
									
									
									
								
							
							
						
						
									
										65
									
								
								main.py
									
										
									
									
									
								
							|  | @ -1,9 +1,9 @@ | |||
| import logging | ||||
| import time | ||||
| import schedule | ||||
| import argparse | ||||
| 
 | ||||
| # import schedule | ||||
| 
 | ||||
| import config | ||||
| 
 | ||||
| from src.rhasspy import rhasspy_mqtt as yoda_listener | ||||
| from src.rhasspy import Rhasspy | ||||
| from src.ratatouille import Ratatouille | ||||
|  | @ -11,30 +11,73 @@ from src.mpd import Mpd | |||
| from src.hass import Hass | ||||
| from src.httpServer import get_server | ||||
| from src.fuzzy import fuzz_predict | ||||
| # from src.intent import AlexaIntent | ||||
| from src.intent import BertIntent | ||||
| from src.tools.simple_sentence_corrector import simple_sentence_corrector | ||||
| 
 | ||||
| 
 | ||||
| # --------- setup args ------------- | ||||
| 
 | ||||
| parser = argparse.ArgumentParser( | ||||
|     prog='Ratatouille', | ||||
|     description='Ratatouille le cerveau domotique !') | ||||
| 
 | ||||
| parser.add_argument('mode') | ||||
| parser.add_argument('-i', '--ip', required=False) | ||||
| parser.add_argument('-p', '--port', required=False, type=int) | ||||
| 
 | ||||
| args = parser.parse_args() | ||||
| 
 | ||||
# Server mode needs an address to bind to. Report the problem through
# argparse so the user gets the usage text on stderr and the process exits
# with a non-zero status (the previous logging.error + exit() ended with
# status 0 and ran before logging was configured).
if args.mode == "server" and (args.ip is None or args.port is None):
    parser.error("--ip or --port argument missing")
| 
 | ||||
| 
 | ||||
| # -------- setup logging ------------ | ||||
| 
 | ||||
| logging.basicConfig( | ||||
|     level=10, | ||||
|     format="%(asctime)s %(filename)s:%(lineno)s %(levelname)s %(message)s" | ||||
| ) | ||||
| 
 | ||||
| IP = "10.10.10.8" | ||||
| PORT = 5555 | ||||
| logging.info("Loading ratatouilles modules") | ||||
| 
 | ||||
| 
 | ||||
| # ---------- other --------------- | ||||
| 
 | ||||
| walle = Hass(config.hass_url, config.hass_token) | ||||
| yoda = None  # Rhasspy(config.rhasspy_url) | ||||
| mopidy = Mpd('10.10.10.8') | ||||
| ratatouille = Ratatouille(yoda, walle, mopidy, schedule) | ||||
| ratatouille = Ratatouille(yoda, walle, mopidy, None) | ||||
| # alexa = AlexaIntent() # we are not doing any request to the evil amazon but we are using one of its dataset | ||||
| bert = BertIntent() | ||||
| 
 | ||||
| 
 | ||||
| def answer(sentence): | ||||
|     # return ratatouille.parse_alexa(alexa.predict(sentence)) | ||||
|     return ratatouille.parse_fuzzy(fuzz_predict(sentence)) | ||||
|     sentence_corrected = simple_sentence_corrector(sentence) | ||||
|     prediction = bert.predict(sentence_corrected) | ||||
|     return ratatouille.parse_fuzzy(prediction) | ||||
|     # return "42" | ||||
| 
 | ||||
| 
 | ||||
| server = get_server(IP, PORT, answer) | ||||
def run_server(ip, port):
    """Create the server for ``ip``/``port`` and serve until interrupted.

    The server object comes from ``get_server`` and is handed ``answer`` as
    its callback. ``ip`` is concatenated into the log message, so it has to
    be a string; ``port`` is converted with ``str``.
    """
    http_server = get_server(ip, port, answer)
    address = ip + ':' + str(port)
    logging.info('Running server on ' + address)
    # Blocks the calling thread; nothing after this call runs until the
    # server stops.
    http_server.serve_forever()
| 
 | ||||
| logging.info('Running server on '+IP+':'+str(PORT)) | ||||
| server.serve_forever() | ||||
| 
 | ||||
def run_prompt():
    """Run an interactive prompt that prints the answer to each sentence.

    Sentences are read from standard input with a "?" prompt and passed to
    ``answer``. Typing "stop" ends the loop; the sentinel itself is not
    forwarded to ``answer``. A closed input stream (EOF) also ends the loop
    instead of raising.
    """
    while True:
        try:
            question = input("?")
        except EOFError:
            # stdin was closed (Ctrl-D or end of piped input): leave cleanly.
            break
        if question == "stop":
            break
        print(answer(question))
| 
 | ||||
| 
 | ||||
| logging.info("Ratatouille is ready !") | ||||
| 
 | ||||
| # run_server() | ||||
# Anything other than "server" falls back to the interactive prompt.
if args.mode != "server":
    run_prompt()
else:
    run_server(str(args.ip), args.port)
|  |  | |||
|  | @ -1,10 +1,12 @@ | |||
| import unittest | ||||
| from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline | ||||
| from transformers import AutoModelForSequenceClassification, TextClassificationPipeline | ||||
| from transformers import pipeline | ||||
| 
 | ||||
| DOMOTIQUE_OBJ_ON = ['iot_wemo_on'] | ||||
| DOMOTIQUE_OBJ_OFF = ['iot_wemo_off'] | ||||
| 
 | ||||
| 
 | ||||
| class AlexaIntent(): | ||||
| 
 | ||||
|     def __init__(self): | ||||
|  | @ -15,21 +17,24 @@ class AlexaIntent(): | |||
|         model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification' | ||||
|         tokenizer = AutoTokenizer.from_pretrained(model_name) | ||||
|         model = AutoModelForSequenceClassification.from_pretrained(model_name) | ||||
|         classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer) | ||||
|         classifier = TextClassificationPipeline( | ||||
|             model=model, tokenizer=tokenizer) | ||||
|         return classifier | ||||
| 
 | ||||
|     def init_entities_classification(self): | ||||
|         tokenizer = AutoTokenizer.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU') | ||||
|         model = AutoModelForTokenClassification.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU') | ||||
|         tokenizer = AutoTokenizer.from_pretrained( | ||||
|             'qanastek/XLMRoberta-Alexa-Intents-NER-NLU') | ||||
|         model = AutoModelForTokenClassification.from_pretrained( | ||||
|             'qanastek/XLMRoberta-Alexa-Intents-NER-NLU') | ||||
|         predict = TokenClassificationPipeline(model=model, tokenizer=tokenizer) | ||||
|         return predict | ||||
| 
 | ||||
|     def simple_sentence_corrector(self,sentence):     | ||||
|         sentence = sentence.replace('étant','éteins') | ||||
|         sentence = sentence.replace('dépeint','éteins') | ||||
|     def simple_sentence_corrector(self, sentence): | ||||
|         sentence = sentence.replace('étant', 'éteins') | ||||
|         sentence = sentence.replace('dépeint', 'éteins') | ||||
|         return sentence | ||||
| 
 | ||||
|     def intent_corrector(self,intents): | ||||
|     def intent_corrector(self, intents): | ||||
|         for intent in intents: | ||||
|             if intent['label'] in DOMOTIQUE_OBJ_ON: | ||||
|                 intent['label'] = 'iot_hue_lighton' | ||||
|  | @ -37,7 +42,7 @@ class AlexaIntent(): | |||
|                 intent['label'] = 'iot_hue_lightoff' | ||||
|         return intents | ||||
| 
 | ||||
|     def predict(self,sentence): | ||||
|     def predict(self, sentence): | ||||
|         sentence = self.simple_sentence_corrector(sentence) | ||||
|         return { | ||||
|             'intents': self.intent_corrector(self.get_intents(sentence)), | ||||
|  | @ -45,6 +50,47 @@ class AlexaIntent(): | |||
|         } | ||||
| 
 | ||||
| 
 | ||||
class BertIntent():
    """Turn a sentence into an intent dict using a text-classification model.

    The classifier is called with the sentence and is expected to return a
    sequence whose first element is a dict with a ``"label"`` and a
    ``"score"`` key; only that first element is used.
    """

    # Predictions scoring below this value are treated as "no intent".
    MIN_SCORE = 0.7

    # Classifier label -> (intent name, optional room argument).
    _LABEL_TO_INTENT = {
        "HEURE": ("GetTime", None),
        "DATE": ("GetDate", None),
        # Was wrongly mapped to 'GetTime' (copy-paste slip): switching the
        # kitchen off must yield a LightOff intent like the other rooms.
        "ETEINDRE_CUISINE": ("LightOff", "cuisine"),
        "ETEINDRE_BUREAU": ("LightOff", "bureau"),
        "ETEINDRE_SALON": ("LightOff", "salon"),
        "ETEINDRE_CHAMBRE": ("LightOff", "chambre"),
        "ALLUMER_CUISINE": ("LightOn", "cuisine"),
        "ALLUMER_SALON": ("LightOn", "salon"),
        "ALLUMER_BUREAU": ("LightOn", "bureau"),
        "ALLUMER_CHAMBRE": ("LightOn", "chambre"),
        "METEO": ("Meteo", None),
        "TEMPERATURE_EXTERIEUR": ("Temperature_ext", None),
        "TEMPERATURE_INTERIEUR": ("Temperature_int", None),
    }

    def __init__(self):
        self.classifier = pipeline("text-classification",
                                   model="Tjiho/french-intents-classificaton")

    def predict(self, sentence):
        """Return the intent for ``sentence``, or ``None`` if there is none.

        Returns a dict with an ``'intentName'`` key and, for light intents,
        an ``'intentArg'`` list holding the room name. ``None`` is returned
        when the top prediction scores below ``MIN_SCORE`` or its label is
        not a known one.
        """
        top = self.classifier(sentence)[0]

        if top["score"] < self.MIN_SCORE:  # score too low
            return None

        mapping = self._LABEL_TO_INTENT.get(top["label"])
        if mapping is None:
            return None

        intent_name, room = mapping
        # Build a fresh dict/list on every call so callers may mutate it.
        intent = {'intentName': intent_name}
        if room is not None:
            intent['intentArg'] = [room]
        return intent
| 
 | ||||
| 
 | ||||
| class TestAlexa(unittest.TestCase): | ||||
|     @classmethod | ||||
|  | @ -60,11 +106,11 @@ class TestAlexa(unittest.TestCase): | |||
|         self.assertEqual(res['intents'][0]['label'], 'iot_hue_lighton') | ||||
|         self.assertEqual(res['entities'][0]['word'], '▁cuisine') | ||||
| 
 | ||||
| 
 | ||||
|     def test_bad_transcribe(self): | ||||
|         res = self.alexa.predict("dépeint la cuisine") | ||||
|         self.assertEqual(res['intents'][0]['label'], 'iot_hue_lightoff') | ||||
|         self.assertEqual(res['entities'][0]['word'], '▁cuisine') | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
							
								
								
									
										4
									
								
								src/tools/simple_sentence_corrector.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								src/tools/simple_sentence_corrector.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,4 @@ | |||
def simple_sentence_corrector(sentence):
    """Replace known mis-transcriptions of "éteins" in ``sentence``.

    The standalone words "étant" and "dépeint" are rewritten to "éteins".
    Only whole words are replaced, so longer words that merely contain one
    of them (e.g. "répétant") are left untouched.
    """
    import re  # local import: this module has no import section of its own

    # \b is Unicode-aware for str patterns, so accented letters count as
    # word characters and the match stops at real word boundaries.
    return re.sub(r'\b(?:étant|dépeint)\b', 'éteins', sentence)
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue