129 lines
5 KiB
Python
129 lines
5 KiB
Python
import unittest
|
|
from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
|
|
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
|
|
from transformers import pipeline
|
|
|
|
# Intent labels that AlexaIntent.intent_corrector rewrites onto the Hue light
# labels: anything listed here becomes 'iot_hue_lighton' / 'iot_hue_lightoff'.
DOMOTIQUE_OBJ_ON = [
    'iot_wemo_on',
]
DOMOTIQUE_OBJ_OFF = [
    'iot_wemo_off',
]
|
|
|
|
|
|
class AlexaIntent():
    """Intent and entity predictor built on two XLM-RoBERTa Alexa checkpoints.

    Construction builds two pipelines and stores them as callables:
    ``get_intents`` (text classification) and ``get_entities`` (token
    classification).
    """

    def __init__(self):
        """Build the intent and entity pipelines."""
        self.get_intents = self.init_intent_classification()
        self.get_entities = self.init_entities_classification()

    def init_intent_classification(self):
        """Return a TextClassificationPipeline for the intent checkpoint."""
        checkpoint = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
        # Tokenizer is loaded before the model, as in the original code.
        return TextClassificationPipeline(
            tokenizer=AutoTokenizer.from_pretrained(checkpoint),
            model=AutoModelForSequenceClassification.from_pretrained(checkpoint),
        )

    def init_entities_classification(self):
        """Return a TokenClassificationPipeline for the NER/NLU checkpoint."""
        checkpoint = 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU'
        return TokenClassificationPipeline(
            tokenizer=AutoTokenizer.from_pretrained(checkpoint),
            model=AutoModelForTokenClassification.from_pretrained(checkpoint),
        )

    def simple_sentence_corrector(self, sentence):
        """Replace known mis-transcriptions of "éteins" in *sentence*.

        Replacements are plain substring substitutions applied in order
        ('étant' first, then 'dépeint'), not word-boundary aware.

        NOTE(review): ``predict`` does not call this method; callers that
        want the correction must apply it themselves.
        """
        fixes = (
            ('étant', 'éteins'),
            ('dépeint', 'éteins'),
        )
        for wrong, right in fixes:
            sentence = sentence.replace(wrong, right)
        return sentence

    def intent_corrector(self, intents):
        """Rewrite Wemo on/off labels to the Hue light labels, in place.

        :param intents: iterable of dicts carrying a ``'label'`` key (the
            output of ``get_intents``). Matching dicts are mutated.
        :returns: the same ``intents`` object that was passed in.
        """
        remaps = (
            (DOMOTIQUE_OBJ_ON, 'iot_hue_lighton'),
            (DOMOTIQUE_OBJ_OFF, 'iot_hue_lightoff'),
        )
        for prediction in intents:
            # Checks run sequentially on the (possibly updated) label.
            for source_labels, target_label in remaps:
                if prediction['label'] in source_labels:
                    prediction['label'] = target_label
        return intents

    def predict(self, sentence):
        """Classify *sentence*.

        :returns: ``{'intents': <remapped intent predictions>,
            'entities': <token-classification output>}``.
        """
        intents = self.intent_corrector(self.get_intents(sentence))
        entities = self.get_entities(sentence)
        return {'intents': intents, 'entities': entities}
|
|
|
|
|
|
class BertIntent:
    """French intent recognizer: a text classifier with a NER fallback.

    ``predict`` maps the classifier's top label to an intent dict of the form
    ``{'intentName': ...}``. The light on/off intents also carry
    ``'intentArg': [room]``. When the classifier's score is below
    ``score_threshold``, the first named entity found in the sentence is
    returned as a ``'search'`` intent instead. When nothing usable is found,
    the result is ``{'intentName': ''}``.
    """

    # Default minimum classifier score for trusting its label.
    score_threshold = 0.7

    # Classifier label -> (intentName, room argument or None).
    _LABEL_TO_INTENT = {
        "HEURE": ("GetTime", None),
        "DATE": ("GetDate", None),
        "ETEINDRE_CUISINE": ("LightOff", "cuisine"),
        "ETEINDRE_BUREAU": ("LightOff", "bureau"),
        "ETEINDRE_SALON": ("LightOff", "salon"),
        "ETEINDRE_CHAMBRE": ("LightOff", "chambre"),
        "ALLUMER_CUISINE": ("LightOn", "cuisine"),
        "ALLUMER_SALON": ("LightOn", "salon"),
        "ALLUMER_BUREAU": ("LightOn", "bureau"),
        "ALLUMER_CHAMBRE": ("LightOn", "chambre"),
        "METEO": ("Meteo", None),
        "TEMPERATURE_EXTERIEUR": ("Temperature_ext", None),
        "TEMPERATURE_INTERIEUR": ("Temperature_int", None),
    }

    def __init__(self, score_threshold=0.7):
        """Build the classification and NER pipelines.

        :param score_threshold: minimum classifier score required to use its
            label. Lower-scoring predictions fall back to entity search.
        """
        self.score_threshold = score_threshold
        self.classifier = pipeline("text-classification",
                                   model="Tjiho/french-intents-classificaton")

        ner_model_name = "Jean-Baptiste/camembert-ner"
        self.ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
        self.ner_model = AutoModelForTokenClassification.from_pretrained(
            ner_model_name)
        self.entity_recognition = pipeline(
            'ner', model=self.ner_model, tokenizer=self.ner_tokenizer,
            aggregation_strategy="simple")

    def predict(self, sentence):
        """Return the intent dict for *sentence*.

        :returns: ``{'intentName': name}`` or
            ``{'intentName': name, 'intentArg': [room]}``. A low-confidence
            classification defers to ``looking_for_entity``. An unmapped
            label yields ``{'intentName': ''}`` (previously ``None``).
        """
        best = self.classifier(sentence)[0]

        if best["score"] < self.score_threshold:  # score too low
            return self.looking_for_entity(sentence)

        mapping = self._LABEL_TO_INTENT.get(best["label"])
        if mapping is None:
            # Label not handled here: report "no intent" rather than None so
            # callers always receive a dict.
            return {'intentName': ''}

        intent_name, room = mapping
        # Build fresh objects each call so callers may mutate the result.
        intent = {'intentName': intent_name}
        if room is not None:
            intent['intentArg'] = [room]
        return intent

    def looking_for_entity(self, sentence):
        """Return a 'search' intent for the first named entity in *sentence*.

        :returns: ``{'intentName': 'search', 'intentArg': [word]}`` using the
            ``'word'`` of the first entity, or ``{'intentName': ''}`` when the
            NER pipeline finds nothing.
        """
        entities = self.entity_recognition(sentence)
        if not entities:
            return {'intentName': ''}
        return {'intentName': 'search', "intentArg": [entities[0]['word']]}
|
|
|
|
|
|
class TestAlexa(unittest.TestCase):
    """Integration tests for AlexaIntent using the real pipelines."""

    @classmethod
    def setUpClass(cls):
        # One predictor instance is shared by every test in this class.
        cls.alexa = AlexaIntent()

    def _check_kitchen(self, sentence, expected_label):
        """Assert the top intent label and that the first entity is 'cuisine'.

        Entity words keep the tokenizer's '▁' word-start prefix.
        """
        result = self.alexa.predict(sentence)
        self.assertEqual(result['intents'][0]['label'], expected_label)
        self.assertEqual(result['entities'][0]['word'], '▁cuisine')

    def test_lampe(self):
        self._check_kitchen("éteins la cuisine", 'iot_hue_lightoff')
        self._check_kitchen("allume la cuisine", 'iot_hue_lighton')

    def test_bad_transcribe(self):
        # "dépeint" is a mis-transcription of "éteins" (see
        # AlexaIntent.simple_sentence_corrector).
        self._check_kitchen("dépeint la cuisine", 'iot_hue_lightoff')
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: discover and run this module's TestCase classes.
    unittest.main()
|