Add fuzzy and text-classification light model #1

Merged
Tjiho merged 18 commits from fuzzy into master 2025-02-01 01:35:11 +01:00
3 changed files with 115 additions and 22 deletions
Showing only changes of commit 07c5184f13 - Show all commits

65
main.py
View file

@ -1,9 +1,9 @@
import logging import logging
import time import argparse
import schedule
# import schedule
import config import config
from src.rhasspy import rhasspy_mqtt as yoda_listener from src.rhasspy import rhasspy_mqtt as yoda_listener
from src.rhasspy import Rhasspy from src.rhasspy import Rhasspy
from src.ratatouille import Ratatouille from src.ratatouille import Ratatouille
@ -11,30 +11,73 @@ from src.mpd import Mpd
from src.hass import Hass from src.hass import Hass
from src.httpServer import get_server from src.httpServer import get_server
from src.fuzzy import fuzz_predict from src.fuzzy import fuzz_predict
# from src.intent import AlexaIntent from src.intent import BertIntent
from src.tools.simple_sentence_corrector import simple_sentence_corrector
# --------- setup args -------------
# Command-line interface: a positional run mode, plus the address and port
# that are only required when the mode is "server".
parser = argparse.ArgumentParser(
    prog='Ratatouille',
    description='Ratatouille le cerveau domotique !')
parser.add_argument('mode')
parser.add_argument('-i', '--ip', required=False)
parser.add_argument('-p', '--port', required=False, type=int)
args = parser.parse_args()

if args.mode == "server" and (args.ip is None or args.port is None):
    # parser.error() prints the usage line and this message on stderr, then
    # exits with status 2. The bare exit() used previously reported success
    # (status 0) even though start-up had failed.
    parser.error("--ip or --port argument missing")
# -------- setup logging ------------
logging.basicConfig( logging.basicConfig(
level=10, level=10,
format="%(asctime)s %(filename)s:%(lineno)s %(levelname)s %(message)s" format="%(asctime)s %(filename)s:%(lineno)s %(levelname)s %(message)s"
) )
IP = "10.10.10.8" logging.info("Loading ratatouilles modules")
PORT = 5555
# ---------- other ---------------
walle = Hass(config.hass_url, config.hass_token) walle = Hass(config.hass_url, config.hass_token)
yoda = None # Rhasspy(config.rhasspy_url) yoda = None # Rhasspy(config.rhasspy_url)
mopidy = Mpd('10.10.10.8') mopidy = Mpd('10.10.10.8')
ratatouille = Ratatouille(yoda, walle, mopidy, schedule) ratatouille = Ratatouille(yoda, walle, mopidy, None)
# alexa = AlexaIntent() # we are not doing any request to the evil amazon but we are using one of its dataset # alexa = AlexaIntent() # we are not doing any request to the evil amazon but we are using one of its dataset
bert = BertIntent()
def answer(sentence): def answer(sentence):
# return ratatouille.parse_alexa(alexa.predict(sentence)) # return ratatouille.parse_alexa(alexa.predict(sentence))
return ratatouille.parse_fuzzy(fuzz_predict(sentence)) sentence_corrected = simple_sentence_corrector(sentence)
prediction = bert.predict(sentence_corrected)
return ratatouille.parse_fuzzy(prediction)
# return "42" # return "42"
server = get_server(IP, PORT, answer) def run_server(ip, port):
server = get_server(ip, port, answer)
logging.info('Running server on '+ip+':'+str(port))
server.serve_forever()
logging.info('Running server on '+IP+':'+str(PORT))
server.serve_forever() def run_prompt():
question = "empty"
while question != "stop":
question = input("?")
print(answer(question))
logging.info("Ratatouille is ready !")

# Pick the front-end from the positional "mode" argument: anything other than
# "server" starts the interactive prompt.
if args.mode != "server":
    run_prompt()
else:
    run_server(str(args.ip), args.port)

View file

@ -1,12 +1,14 @@
import unittest import unittest
from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
from transformers import pipeline
DOMOTIQUE_OBJ_ON = ['iot_wemo_on'] DOMOTIQUE_OBJ_ON = ['iot_wemo_on']
DOMOTIQUE_OBJ_OFF = ['iot_wemo_off'] DOMOTIQUE_OBJ_OFF = ['iot_wemo_off']
class AlexaIntent(): class AlexaIntent():
def __init__(self): def __init__(self):
self.get_intents = self.init_intent_classification() self.get_intents = self.init_intent_classification()
self.get_entities = self.init_entities_classification() self.get_entities = self.init_entities_classification()
@ -15,21 +17,24 @@ class AlexaIntent():
model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification' model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer) classifier = TextClassificationPipeline(
model=model, tokenizer=tokenizer)
return classifier return classifier
def init_entities_classification(self): def init_entities_classification(self):
tokenizer = AutoTokenizer.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU') tokenizer = AutoTokenizer.from_pretrained(
model = AutoModelForTokenClassification.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU') 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
model = AutoModelForTokenClassification.from_pretrained(
'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
predict = TokenClassificationPipeline(model=model, tokenizer=tokenizer) predict = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
return predict return predict
def simple_sentence_corrector(self,sentence): def simple_sentence_corrector(self, sentence):
sentence = sentence.replace('étant','éteins') sentence = sentence.replace('étant', 'éteins')
sentence = sentence.replace('dépeint','éteins') sentence = sentence.replace('dépeint', 'éteins')
return sentence return sentence
def intent_corrector(self,intents): def intent_corrector(self, intents):
for intent in intents: for intent in intents:
if intent['label'] in DOMOTIQUE_OBJ_ON: if intent['label'] in DOMOTIQUE_OBJ_ON:
intent['label'] = 'iot_hue_lighton' intent['label'] = 'iot_hue_lighton'
@ -37,7 +42,7 @@ class AlexaIntent():
intent['label'] = 'iot_hue_lightoff' intent['label'] = 'iot_hue_lightoff'
return intents return intents
def predict(self,sentence): def predict(self, sentence):
sentence = self.simple_sentence_corrector(sentence) sentence = self.simple_sentence_corrector(sentence)
return { return {
'intents': self.intent_corrector(self.get_intents(sentence)), 'intents': self.intent_corrector(self.get_intents(sentence)),
@ -45,6 +50,47 @@ class AlexaIntent():
} }
class BertIntent():
    """Turn a sentence into an intent dict using a text-classification model.

    The classifier is a ``transformers`` text-classification pipeline; its
    best label is translated into the ``{'intentName': ..., 'intentArg': [...]}``
    shape returned by :meth:`predict`.
    """

    # Predictions scoring below this threshold are treated as "no intent".
    MIN_SCORE = 0.7

    # Classifier label -> (intent name, room argument or None).
    LABEL_TO_INTENT = {
        "HEURE": ("GetTime", None),
        "DATE": ("GetDate", None),
        "ETEINDRE_CUISINE": ("LightOff", "cuisine"),
        "ETEINDRE_BUREAU": ("LightOff", "bureau"),
        "ETEINDRE_SALON": ("LightOff", "salon"),
        "ETEINDRE_CHAMBRE": ("LightOff", "chambre"),
        "ALLUMER_CUISINE": ("LightOn", "cuisine"),
        "ALLUMER_SALON": ("LightOn", "salon"),
        "ALLUMER_BUREAU": ("LightOn", "bureau"),
        "ALLUMER_CHAMBRE": ("LightOn", "chambre"),
        "METEO": ("Meteo", None),
        "TEMPERATURE_EXTERIEUR": ("Temperature_ext", None),
        "TEMPERATURE_INTERIEUR": ("Temperature_int", None),
    }

    def __init__(self):
        """Load the text-classification pipeline used by :meth:`predict`."""
        self.classifier = pipeline("text-classification",
                                   model="Tjiho/french-intents-classificaton")

    def predict(self, sentence):
        """Classify ``sentence`` and return the matching intent.

        Returns a new dict with an ``'intentName'`` key and, for the
        room-specific light intents, an ``'intentArg'`` list holding the room
        name. Returns ``None`` when the classifier yields no result, when the
        best score is below ``MIN_SCORE``, or when the label is not in
        ``LABEL_TO_INTENT``.
        """
        classification = self.classifier(sentence)
        if not classification:
            return None

        best = classification[0]
        if best["score"] < self.MIN_SCORE:  # score too low
            return None

        mapping = self.LABEL_TO_INTENT.get(best["label"])
        if mapping is None:
            return None

        intent_name, room = mapping
        # Build a fresh dict on every call so callers may mutate the result
        # without corrupting the shared lookup table.
        intent = {'intentName': intent_name}
        if room is not None:
            intent['intentArg'] = [room]
        return intent
class TestAlexa(unittest.TestCase): class TestAlexa(unittest.TestCase):
@classmethod @classmethod
@ -60,11 +106,11 @@ class TestAlexa(unittest.TestCase):
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lighton') self.assertEqual(res['intents'][0]['label'], 'iot_hue_lighton')
self.assertEqual(res['entities'][0]['word'], '▁cuisine') self.assertEqual(res['entities'][0]['word'], '▁cuisine')
def test_bad_transcribe(self): def test_bad_transcribe(self):
res = self.alexa.predict("dépeint la cuisine") res = self.alexa.predict("dépeint la cuisine")
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lightoff') self.assertEqual(res['intents'][0]['label'], 'iot_hue_lightoff')
self.assertEqual(res['entities'][0]['word'], '▁cuisine') self.assertEqual(res['entities'][0]['word'], '▁cuisine')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -0,0 +1,4 @@
def simple_sentence_corrector(sentence):
    """Return ``sentence`` with known wrong words replaced by 'éteins'.

    Every occurrence of 'étant' and of 'dépeint' is substituted; all other
    text is left untouched.
    """
    for wrong_word in ('étant', 'dépeint'):
        sentence = sentence.replace(wrong_word, 'éteins')
    return sentence