Add fuzzy and text-classification light model #1

Merged
Tjiho merged 18 commits from fuzzy into master 2025-02-01 01:35:11 +01:00
3 changed files with 115 additions and 22 deletions
Showing only changes of commit 07c5184f13 - Show all commits

65
main.py
View file

@ -1,9 +1,9 @@
import logging
import time
import schedule
import argparse
# import schedule
import config
from src.rhasspy import rhasspy_mqtt as yoda_listener
from src.rhasspy import Rhasspy
from src.ratatouille import Ratatouille
@ -11,30 +11,73 @@ from src.mpd import Mpd
from src.hass import Hass
from src.httpServer import get_server
from src.fuzzy import fuzz_predict
# from src.intent import AlexaIntent
from src.intent import BertIntent
from src.tools.simple_sentence_corrector import simple_sentence_corrector
# --------- setup args -------------
parser = argparse.ArgumentParser(
    prog='Ratatouille',
    description='Ratatouille le cerveau domotique !')
# Positional run mode: "server" starts the HTTP server, any other value
# falls through to the interactive prompt (see the dispatch at the bottom
# of this file).
parser.add_argument('mode')
# Only needed in server mode, hence optional at the argparse level.
parser.add_argument('-i', '--ip', required=False)
parser.add_argument('-p', '--port', required=False, type=int)
args = parser.parse_args()

if args.mode == "server" and (args.ip is None or args.port is None):
    # parser.error() prints the usage line plus this message on stderr and
    # terminates with exit status 2. The previous code logged the problem
    # (before logging was configured) and then called exit(), which reports
    # success (status 0) to the calling shell / service manager.
    parser.error("--ip or --port argument missing")

# -------- setup logging ------------
logging.basicConfig(
    level=logging.DEBUG,  # same numeric level (10) as before
    format="%(asctime)s %(filename)s:%(lineno)s %(levelname)s %(message)s"
)
# Legacy hard-coded endpoint; kept because other module-level code may still
# reference these names.
IP = "10.10.10.8"
PORT = 5555

logging.info("Loading ratatouilles modules")

# ---------- other ---------------
walle = Hass(config.hass_url, config.hass_token)
yoda = None  # Rhasspy(config.rhasspy_url)
mopidy = Mpd('10.10.10.8')

# Build the Ratatouille instance exactly once. The scheduler slot is None
# because the `schedule` integration is currently disabled; the earlier
# construction that passed `schedule` was immediately overwritten and has
# been removed.
ratatouille = Ratatouille(yoda, walle, mopidy, None)

# alexa = AlexaIntent() # we are not doing any request to the evil amazon but we are using one of its dataset
bert = BertIntent()
def answer(sentence):
    """Return Ratatouille's reply for one user sentence.

    The sentence goes through ``simple_sentence_corrector``, is classified
    with ``bert.predict`` and the resulting prediction is handed to
    ``ratatouille.parse_fuzzy``, whose return value is returned unchanged.

    A leftover ``return ratatouille.parse_fuzzy(fuzz_predict(sentence))``
    used to sit in front of these steps and made them unreachable; it has
    been removed so the BERT path is the one actually executed.
    """
    sentence_corrected = simple_sentence_corrector(sentence)
    prediction = bert.predict(sentence_corrected)
    # NOTE(review): the prediction may be None when the classifier is not
    # confident enough - confirm that parse_fuzzy accepts None.
    return ratatouille.parse_fuzzy(prediction)
def run_server(ip, port):
    """Create the HTTP server for ``ip``:``port`` and serve requests.

    The server is obtained from ``get_server`` with ``answer`` as the
    request callback, then its ``serve_forever()`` is called.

    The server is built here, on demand, instead of at module level with
    the hard-coded IP/PORT constants: the previous module-level
    ``get_server(...)`` / ``serve_forever()`` statements duplicated this
    function and would start serving before the command-line dispatch ran.

    Args:
        ip: address to bind, as a string.
        port: TCP port to bind, as an int.
    """
    server = get_server(ip, port, answer)
    logging.info('Running server on %s:%s', ip, port)
    server.serve_forever()
def run_prompt():
    """Interactive console loop: read a sentence, print its answer.

    The loop ends when the user types ``stop`` or when standard input is
    closed (EOF). The ``stop`` command itself is not forwarded to
    ``answer``; previously it was classified and answered like any other
    sentence before the loop terminated.
    """
    while True:
        try:
            question = input("?")
        except EOFError:
            # stdin closed (Ctrl-D / end of piped input): leave quietly
            # instead of crashing with a traceback.
            break
        if question == "stop":
            break
        print(answer(question))
logging.info("Ratatouille is ready !")

# Pick the front-end from the positional "mode" argument: the HTTP server
# for "server", the interactive prompt for anything else.
if args.mode != "server":
    run_prompt()
else:
    run_server(str(args.ip), args.port)

View file

@ -1,12 +1,14 @@
import unittest
from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
from transformers import pipeline
DOMOTIQUE_OBJ_ON = ['iot_wemo_on']
DOMOTIQUE_OBJ_OFF = ['iot_wemo_off']
class AlexaIntent():
def __init__(self):
self.get_intents = self.init_intent_classification()
self.get_entities = self.init_entities_classification()
@ -15,21 +17,24 @@ class AlexaIntent():
model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
classifier = TextClassificationPipeline(
model=model, tokenizer=tokenizer)
return classifier
def init_entities_classification(self):
tokenizer = AutoTokenizer.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
model = AutoModelForTokenClassification.from_pretrained('qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
tokenizer = AutoTokenizer.from_pretrained(
'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
model = AutoModelForTokenClassification.from_pretrained(
'qanastek/XLMRoberta-Alexa-Intents-NER-NLU')
predict = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
return predict
def simple_sentence_corrector(self,sentence):
sentence = sentence.replace('étant','éteins')
sentence = sentence.replace('dépeint','éteins')
def simple_sentence_corrector(self, sentence):
sentence = sentence.replace('étant', 'éteins')
sentence = sentence.replace('dépeint', 'éteins')
return sentence
def intent_corrector(self,intents):
def intent_corrector(self, intents):
for intent in intents:
if intent['label'] in DOMOTIQUE_OBJ_ON:
intent['label'] = 'iot_hue_lighton'
@ -37,7 +42,7 @@ class AlexaIntent():
intent['label'] = 'iot_hue_lightoff'
return intents
def predict(self,sentence):
def predict(self, sentence):
sentence = self.simple_sentence_corrector(sentence)
return {
'intents': self.intent_corrector(self.get_intents(sentence)),
@ -45,6 +50,47 @@ class AlexaIntent():
}
class BertIntent():
    """Translate a sentence into a Ratatouille intent via text classification.

    The classifier is a ``text-classification`` pipeline loaded in
    ``__init__``; ``predict`` maps its top label onto an intent dictionary.
    No sentence correction is applied here: callers pass the sentence they
    want classified.
    """

    # Predictions whose score is strictly below this value are discarded.
    MIN_SCORE = 0.7

    # Classifier label -> (intentName, intentArg values or None).
    # ETEINDRE_CUISINE used to be mapped to 'GetTime', so asking to switch
    # the kitchen light off was answered with the time; it now maps to
    # 'LightOff' like the other ETEINDRE_* labels.
    LABEL_TO_INTENT = {
        "HEURE": ('GetTime', None),
        "DATE": ('GetDate', None),
        "ETEINDRE_CUISINE": ('LightOff', ('cuisine',)),
        "ETEINDRE_BUREAU": ('LightOff', ('bureau',)),
        "ETEINDRE_SALON": ('LightOff', ('salon',)),
        "ETEINDRE_CHAMBRE": ('LightOff', ('chambre',)),
        "ALLUMER_CUISINE": ('LightOn', ('cuisine',)),
        "ALLUMER_SALON": ('LightOn', ('salon',)),
        "ALLUMER_BUREAU": ('LightOn', ('bureau',)),
        "ALLUMER_CHAMBRE": ('LightOn', ('chambre',)),
        "METEO": ('Meteo', None),
        "TEMPERATURE_EXTERIEUR": ('Temperature_ext', None),
        "TEMPERATURE_INTERIEUR": ('Temperature_int', None),
    }

    def __init__(self):
        self.classifier = pipeline("text-classification",
                                   model="Tjiho/french-intents-classificaton")

    def predict(self, sentence):
        """Classify ``sentence`` and return the matching intent.

        Args:
            sentence: the text to classify.

        Returns:
            A new dict with key ``'intentName'`` and, for intents that take
            an argument, ``'intentArg'`` (a list of strings). ``None`` when
            the classifier returns nothing, when the best score is below
            ``MIN_SCORE``, or when the label has no known intent.
        """
        classification = self.classifier(sentence)
        if not classification:
            return None

        best = classification[0]
        if best["score"] < self.MIN_SCORE:  # score too low
            return None

        mapping = self.LABEL_TO_INTENT.get(best["label"])
        if mapping is None:
            return None

        intent_name, intent_arg = mapping
        # Build a fresh dict (and list) on every call so callers can mutate
        # the result without corrupting the shared lookup table.
        intent = {'intentName': intent_name}
        if intent_arg is not None:
            intent['intentArg'] = list(intent_arg)
        return intent
class TestAlexa(unittest.TestCase):
@classmethod
@ -60,11 +106,11 @@ class TestAlexa(unittest.TestCase):
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lighton')
self.assertEqual(res['entities'][0]['word'], '▁cuisine')
def test_bad_transcribe(self):
res = self.alexa.predict("dépeint la cuisine")
self.assertEqual(res['intents'][0]['label'], 'iot_hue_lightoff')
self.assertEqual(res['entities'][0]['word'], '▁cuisine')
# Run the test suite only when this module is executed as a script; a single
# guarded call replaces the duplicated unittest.main() invocation.
if __name__ == '__main__':
    unittest.main()

View file

@ -0,0 +1,4 @@
def simple_sentence_corrector(sentence):
    """Return ``sentence`` with known mis-transcriptions replaced.

    Every occurrence of 'étant' and of 'dépeint' is replaced by 'éteins'.
    """
    for misheard in ('étant', 'dépeint'):
        sentence = sentence.replace(misheard, 'éteins')
    return sentence