# ratatouille/src/intent.py
#
# 129 lines
# 5 KiB
# Python

import unittest
from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
from transformers import pipeline
# Alexa intent labels that AlexaIntent.intent_corrector rewrites to the
# corresponding hue-light intents ('iot_hue_lighton' / 'iot_hue_lightoff').
DOMOTIQUE_OBJ_ON = [
    'iot_wemo_on',
]
DOMOTIQUE_OBJ_OFF = [
    'iot_wemo_off',
]
class AlexaIntent:
    """Intent and entity extraction backed by the qanastek XLM-RoBERTa Alexa models.

    Two Hugging Face pipelines are built when the object is created:
    ``get_intents`` (sequence classification) and ``get_entities``
    (token classification).
    """

    def __init__(self):
        # The intent pipeline is built first, then the entity pipeline.
        self.get_intents = self.init_intent_classification()
        self.get_entities = self.init_entities_classification()

    def init_intent_classification(self):
        """Load the intent classifier and return it wrapped in a TextClassificationPipeline."""
        checkpoint = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
        # Tokenizer is loaded before the model, as in the original code.
        return TextClassificationPipeline(
            tokenizer=AutoTokenizer.from_pretrained(checkpoint),
            model=AutoModelForSequenceClassification.from_pretrained(checkpoint),
        )

    def init_entities_classification(self):
        """Load the NER model and return it wrapped in a TokenClassificationPipeline."""
        checkpoint = 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU'
        ner_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)
        return TokenClassificationPipeline(model=ner_model, tokenizer=ner_tokenizer)

    def simple_sentence_corrector(self, sentence):
        """Rewrite known mis-transcriptions of 'éteins' and return the new sentence.

        Replacements are plain substring substitutions applied in order.
        """
        for heard, meant in (('étant', 'éteins'), ('dépeint', 'éteins')):
            sentence = sentence.replace(heard, meant)
        return sentence

    def intent_corrector(self, intents):
        """Relabel WeMo on/off intents as hue light on/off intents.

        The dicts in ``intents`` are modified in place; the same iterable
        is returned for convenience.
        """
        remaps = (
            (DOMOTIQUE_OBJ_ON, 'iot_hue_lighton'),
            (DOMOTIQUE_OBJ_OFF, 'iot_hue_lightoff'),
        )
        for entry in intents:
            for source_labels, replacement in remaps:
                if entry['label'] in source_labels:
                    entry['label'] = replacement
        return intents

    def predict(self, sentence):
        """Return ``{'intents': [...], 'entities': [...]}`` for ``sentence``."""
        intents = self.intent_corrector(self.get_intents(sentence))
        entities = self.get_entities(sentence)
        return {'intents': intents, 'entities': entities}
class BertIntent:
    """Intent recognition with a fine-tuned text classifier and a NER fallback.

    ``predict`` maps the classifier's top label to an intent dict such as
    ``{'intentName': 'LightOn', 'intentArg': ['salon']}``. When the classifier
    is not confident enough, or returns a label this class does not map,
    the sentence is passed to ``looking_for_entity`` instead.
    """

    # Top-label scores strictly below this value are treated as "no intent".
    # Kept as a class attribute so callers/subclasses can override it.
    SCORE_THRESHOLD = 0.7

    # Classifier label -> (intentName, single intentArg or None).
    # The argument is stored as a plain string; a fresh list is built for
    # every result so callers can never mutate this shared table.
    _LABEL_TO_INTENT = {
        "HEURE": ('GetTime', None),
        "DATE": ('GetDate', None),
        "ETEINDRE_CUISINE": ('LightOff', 'cuisine'),
        "ETEINDRE_BUREAU": ('LightOff', 'bureau'),
        "ETEINDRE_SALON": ('LightOff', 'salon'),
        "ETEINDRE_CHAMBRE": ('LightOff', 'chambre'),
        "ALLUMER_CUISINE": ('LightOn', 'cuisine'),
        "ALLUMER_SALON": ('LightOn', 'salon'),
        "ALLUMER_BUREAU": ('LightOn', 'bureau'),
        "ALLUMER_CHAMBRE": ('LightOn', 'chambre'),
        "METEO": ('Meteo', None),
        "TEMPERATURE_EXTERIEUR": ('Temperature_ext', None),
        "TEMPERATURE_INTERIEUR": ('Temperature_int', None),
    }

    def __init__(self):
        self.classifier = pipeline("text-classification",
                                   model="Tjiho/french-intents-classificaton")
        ner_checkpoint = "Jean-Baptiste/camembert-ner"
        self.ner_tokenizer = AutoTokenizer.from_pretrained(ner_checkpoint)
        self.ner_model = AutoModelForTokenClassification.from_pretrained(
            ner_checkpoint)
        # aggregation_strategy="simple" groups sub-word tokens into entities.
        self.entity_recognition = pipeline(
            'ner', model=self.ner_model, tokenizer=self.ner_tokenizer,
            aggregation_strategy="simple")

    def predict(self, sentence):
        """Classify ``sentence`` and return an intent dict.

        Returns:
            ``{'intentName': name}`` or ``{'intentName': name, 'intentArg': [arg]}``
            for a confidently recognised label. Otherwise, the result of
            ``looking_for_entity(sentence)``.
        """
        best = self.classifier(sentence)[0]
        if best["score"] < self.SCORE_THRESHOLD:  # score too low
            return self.looking_for_entity(sentence)

        mapping = self._LABEL_TO_INTENT.get(best["label"])
        if mapping is None:
            # Unknown label: previously the if/elif chain fell through and
            # returned None, which broke callers expecting a dict.
            return self.looking_for_entity(sentence)

        intent_name, intent_arg = mapping
        result = {'intentName': intent_name}
        if intent_arg is not None:
            result['intentArg'] = [intent_arg]
        return result

    def looking_for_entity(self, sentence):
        """Fallback: turn the first named entity into a 'search' intent.

        Returns ``{'intentName': 'search', 'intentArg': [word]}`` using the
        first entity's ``'word'``, or ``{'intentName': ''}`` when no entity
        is found.
        """
        entities = self.entity_recognition(sentence)
        if entities:
            return {'intentName': 'search', "intentArg": [entities[0]['word']]}
        return {'intentName': ''}
class TestAlexa(unittest.TestCase):
    """Integration tests for AlexaIntent; setUpClass loads the real models."""

    @classmethod
    def setUpClass(cls):
        # One shared AlexaIntent instance for every test in this class.
        cls.alexa = AlexaIntent()

    def _check_prediction(self, sentence, expected_label):
        """Assert the first intent label, then that the first entity is 'cuisine'."""
        prediction = self.alexa.predict(sentence)
        self.assertEqual(prediction['intents'][0]['label'], expected_label)
        # The expected entity word keeps the tokenizer's leading '▁' marker.
        self.assertEqual(prediction['entities'][0]['word'], '▁cuisine')

    def test_lampe(self):
        self._check_prediction("éteins la cuisine", 'iot_hue_lightoff')
        self._check_prediction("allume la cuisine", 'iot_hue_lighton')

    def test_bad_transcribe(self):
        # "dépeint" is a speech-to-text mishearing of "éteins".
        self._check_prediction("dépeint la cuisine", 'iot_hue_lightoff')
if __name__ == '__main__':
    # Running this file directly executes the unittest suite it defines.
    unittest.main(module='__main__')