import string
from datetime import datetime
import nltk
import sklearn_crfsuite
from sklearn.model_selection import train_test_split
from sklearn_crfsuite import metrics
from ..argument_pipeline.base import Model
from ..corpora import load_essay_corpus
from ..corpora._essay_corpus import tokenize_essay_sentences
from ..nlp import Lemmatiser
from ..nlp._utils import nltk_download
from ..utils import logger, get_is_dev
lemmatiser = Lemmatiser()
__all__ = [
"ArgumentSegmenter"
]
[docs]class ArgumentSegmenter(Model):
"""Argument Segmenter using Conditional Random Fields.
Examples
--------
>>> import canary
>>> segmenter = canary.load_model("arg_segmenter")
>>> sentence = "To sum up, I believe that a higher pay can be one of the incentive if were to encourage harder-working employees."
>>> print(segmenter.predict(sentence))
[('To', 'O'), ('sum', 'O'), ('up', 'O'), (',', 'O'), ('I', 'O'), ('believe', 'O'), ('that', 'O'), ('a', 'Arg-B'), ('higher', 'Arg-I'), ('pay', 'Arg-I'), ('can', 'Arg-I'), ('be', 'Arg-I'), ('one', 'Arg-I'), ('of', 'Arg-I'), ('the', 'Arg-I'), ('incentive', 'Arg-I'), ('if', 'Arg-I'), ('were', 'Arg-I'), ('to', 'Arg-I'), ('encourage', 'Arg-I'), ('harder-working', 'Arg-I'), ('employees', 'Arg-I'), ('.', 'O')]
"""
def __init__(self, model_id=None):
if model_id is None:
model_id = "arg_segmenter"
super().__init__(
model_id=model_id,
)
[docs] @staticmethod
def default_train():
# Need to get data into a usable shape
logger.debug("Getting training data")
x, y = load_essay_corpus(
purpose="sequence_labelling"
)
train_data, test_data, train_targets, test_targets = \
train_test_split(x, y,
train_size=0.8,
shuffle=True,
random_state=0,
)
logger.debug("Getting training features")
train_data = [_get_sentence_features(s) for s in train_data]
logger.debug("Getting training labels")
train_targets = [_get_labels(s) for s in train_targets]
logger.debug("Getting test features")
test_data = [_get_sentence_features(s) for s in test_data]
logger.debug("Getting test labels")
test_targets = [_get_labels(s) for s in test_targets]
return train_data, test_data, train_targets, test_targets
[docs] @classmethod
def train(cls, pipeline_model=None, train_data=None, test_data=None, train_targets=None, test_targets=None,
save_on_finish=True, *args, **kwargs):
model = cls()
if any(item is None for item in [train_data, test_data, train_targets, test_targets]):
# get default data if the above is not present
train_data, test_data, train_targets, test_targets = model.default_train()
logger.debug("Training algorithm")
if pipeline_model is None:
pipeline_model = sklearn_crfsuite.CRF(
algorithm='l2sgd',
all_possible_transitions=True,
all_possible_states=True,
)
model.set_model(pipeline_model)
model.fit(train_data, train_targets)
labels = list(pipeline_model.classes_)
y_pred = pipeline_model.predict(test_data)
sorted_labels = sorted(
labels,
key=lambda name: (name[1:], name[0])
)
logger.debug("\n\n" + metrics.flat_classification_report(
test_targets, y_pred, labels=sorted_labels, digits=4
))
model._metrics = metrics.flat_classification_report(
test_targets, y_pred, labels=sorted_labels, digits=4, output_dict=True
)
if get_is_dev() is True:
from ._utils import log_training_data
log_training_data({"result": metrics.flat_classification_report(
test_targets, y_pred, labels=sorted_labels, digits=4, output_dict=True
),
"datetime": str(datetime.now()), "model": model.__class__.__name__})
if save_on_finish is True:
model.save()
return model
[docs] def predict(self, data, probability=False, binary=False):
nltk_download(['punkt', 'averaged_perceptron_tagger'])
if probability is True:
logger.warn(
f"{self.__class__.__name__} does not support probability predictions. This parameter is ignored.")
data_type = type(data)
if data_type is str:
tokens = nltk.word_tokenize(data)
data = [_get_sentence_features(tokens)]
predictions = super().predict(data, probability=False)[0]
if binary is not None:
if binary is True:
if all(k == "O" for k in super().predict(data, probability=False)[0]):
return False
else:
return True
return list(zip(tokens, predictions))
if data_type is list:
if all(type(item) is dict for item in data) is False:
logger.error("The list passed in needs to only contain dictionary features")
return
return super().predict(data, probability=False)
[docs] def get_components_from_document(self, document: str) -> list:
"""Helper method which extracts components from a document which have been identified as argument
spans.
Parameters
----------
document: str
The document which is to be analysed.
Returns
-------
list
A list of dictionary items which detail the components that have been identified.
"""
from nltk.tokenize.treebank import TreebankWordDetokenizer
detokenizer = TreebankWordDetokenizer()
# Segment from full text
components = []
current_component = []
sentences = tokenize_essay_sentences(document)
if len(sentences) < 2:
logger.warn("There doesn't seem to be much to analyse in the document.")
predictions = [self.predict(sentence) for sentence in sentences]
# @TODO Ensure this works properly
for prediction in predictions:
for i, token in enumerate(prediction):
if token[1] == "Arg-B":
current_component = [token[0]]
elif token[1] == "Arg-I":
current_component.append(token[0])
if i < len(prediction):
if prediction[i + 1][1] == "O":
components.append(current_component)
current_component = []
# Delete these
del current_component
del prediction
del predictions
logger.debug(f"{len(components)} components found from segmenter.")
# Get covering sentences
for i, component in enumerate(components):
for sen in sentences:
if detokenizer.detokenize(component) in sen or all(x in nltk.word_tokenize(sen) for x in component):
components[i] = {
'component_ref': i,
"cover_sentence": sen,
"component": detokenizer.detokenize(component),
"len_component": len(component),
"len_cover_sen": len(nltk.word_tokenize(sen)),
"tokens": component
}
split = components[i]['cover_sentence'].split((components[i]['component']))
try:
components[i].update({
'n_following_comp_tokens': len(nltk.word_tokenize(split[0])),
'n_preceding_comp_tokens': len(nltk.word_tokenize(split[1])),
})
except IndexError:
# kind of hackish but detokenising isn't a perfect process
cs = detokenizer.detokenize(nltk.word_tokenize(components[i]['cover_sentence']))
split = cs.split((components[i]['component']))
try:
components[i].update({
'n_following_comp_tokens': len(nltk.word_tokenize(split[0])),
'n_preceding_comp_tokens': len(nltk.word_tokenize(split[1])),
})
except IndexError as e:
# @TODO Fix this bit
logger.error(e)
components[i].update({
'n_following_comp_tokens': 0,
'n_preceding_comp_tokens': 0,
})
paragraphs = [p.strip() for p in document.split("\n") if p and not p.isspace()]
logger.debug(f"{len(paragraphs)} paragraphs in document.")
if not all('tokens' in c for c in components):
raise KeyError("There was an error finding argumentative components")
for i, component in enumerate(components):
for j, para in enumerate(paragraphs):
if detokenizer.detokenize(component['tokens']) in para or all(
x in nltk.word_tokenize(para) for x in component['tokens']):
components[i].update({
'len_paragraph': len(nltk.word_tokenize(para)),
'para_ref': j + 1,
'is_in_intro': True if j == (0 or 1) else False,
'is_in_conclusion': True if j == (len(paragraphs) - 1) else False,
})
# find n_following_components and n_preceding_components
for component in components:
if 'para_ref' not in component:
raise ValueError("failed to find ...")
neighbouring_components = [c for c in components if
c['para_ref'] == component['para_ref'] and c != component]
if len(neighbouring_components) < 2:
component['n_following_components'] = 0
component['n_preceding_components'] = 0
component['component_position'] = 1
component['first_in_paragraph'] = True
component['last_in_paragraph'] = True
else:
component['n_preceding_components'] = len(
[c for c in neighbouring_components if c['component_ref'] < component['component_ref']])
component['n_following_components'] = len(
[c for c in neighbouring_components if c['component_ref'] > component['component_ref']])
component['component_position'] = (len(neighbouring_components) - component[
'n_following_components']) + 1
component['first_in_paragraph'] = True if component["n_preceding_components"] == 0 else False
component['last_in_paragraph'] = True if component["n_following_components"] == 0 else False
# find if indicator type is...
from canary.nlp.transformers import DiscourseMatcher
forward_matcher = DiscourseMatcher('forward')
thesis_matcher = DiscourseMatcher('thesis')
rebuttal_matcher = DiscourseMatcher('rebuttal')
backward_matcher = DiscourseMatcher('backward')
for i, component in enumerate(components):
para_components = [c for c in components if c['para_ref'] == component['para_ref']]
prev_components = para_components[:i]
following_components = para_components[i + 1:]
component['indicator_type_precedes_component'] = False
component['indicator_type_follows_component'] = False
for c in prev_components:
if forward_matcher.transform(c['cover_sentence'])[0][0] is True:
component['indicator_type_precedes_component'] = True
break
elif thesis_matcher.transform(c['cover_sentence'])[0][0] is True:
component['indicator_type_precedes_component'] = True
break
elif rebuttal_matcher.transform(c['cover_sentence'])[0][0] is True:
component['indicator_type_precedes_component'] = True
break
elif backward_matcher.transform(c['cover_sentence'])[0][0] is True:
component['indicator_type_precedes_component'] = True
break
for c in following_components:
if forward_matcher.transform(c['cover_sentence'])[0][0] is True:
component['indicator_type_follows_component'] = True
break
elif thesis_matcher.transform(c['cover_sentence'])[0][0] is True:
component['indicator_type_follows_component'] = True
break
elif rebuttal_matcher.transform(c['cover_sentence'])[0][0] is True:
component['indicator_type_follows_component'] = True
break
elif backward_matcher.transform(c['cover_sentence'])[0][0] is True:
component['indicator_type_follows_component'] = True
break
for c in components:
del c['tokens']
return components
def _get_word_features(sent, i):
word = sent[i][0]
features = {
'bias': 1.0,
'word.lower()': str(word).lower(),
'word.is_lower': word.isupper(),
'word.istitle()': word.istitle(),
'word.isdigit()': word.isdigit(),
'postag': sent[i][1],
'ent': sent[i][2],
'lemma': lemmatiser(word)[0],
"period": word == '.',
'is_punct': word in string.punctuation,
'len': len(word)
}
if i > 0:
word1 = sent[i - 1][0]
postag1 = sent[i - 1][1]
ent1 = sent[i - 1][2]
features.update({
'-1:word.lower()': str(word1).lower(),
'-1:word.istitle()': word1.istitle(),
'-1:word.isupper()': word1.isupper(),
'-1:postag': postag1,
'-1:lemma': lemmatiser(word1)[0],
'-1:ent': ent1,
'-1:punct': word1 in string.punctuation
})
else:
features['BOS'] = True
if i > 1:
word2 = sent[i - 2][0]
postag2 = sent[i - 2][1]
ent2 = sent[i - 2][2]
features.update({
'-2:word.lower()': str(word2).lower(),
'-2:word.istitle()': word2.istitle(),
'-2:word.isupper()': word2.isupper(),
'-2:postag': postag2,
'-2:lemma': lemmatiser(word2)[0],
'-2:ent': ent2,
})
if i > 2:
word3 = sent[i - 3][0]
postag3 = sent[i - 3][1]
ent3 = sent[i - 3][2]
features.update({
'-3:word.lower()': str(word3).lower(),
'-3:word.istitle()': word3.istitle(),
'-3:word.isupper()': word3.isupper(),
'-3:postag': postag3,
'-3:lemma': lemmatiser(word3)[0],
'-3:ent': ent3,
})
if i > 3:
word4 = sent[i - 4][0]
postag4 = sent[i - 4][1]
ent4 = sent[i - 4][2]
features.update({
'-4:word.lower()': str(word4).lower(),
'-4:word.istitle()': word4.istitle(),
'-4:word.isupper()': word4.isupper(),
'-4:postag': postag4,
'-4:lemma': lemmatiser(word4)[0],
'-4:ent': ent4,
})
if i > 4:
word5 = sent[i - 5][0]
postag5 = sent[i - 5][1]
ent5 = sent[i - 5][2]
features.update({
'-5:word.lower()': str(word5).lower(),
'-5:word.istitle()': word5.istitle(),
'-5:word.isupper()': word5.isupper(),
'-5:postag': postag5,
'-5:lemma': lemmatiser(word5)[0],
'-5:ent': ent5,
})
if i < len(sent) - 5:
word5 = sent[i + 5][0]
postag5 = sent[i + 5][1]
ent5 = sent[i + 5][2]
features.update({
'+5:word.lower()': str(word5).lower(),
'+5:word.istitle()': word5.istitle(),
'+5:word.isupper()': word5.isupper(),
'+5:postag': postag5,
'+5:lemma': lemmatiser(word5)[0],
'+5:ent': ent5,
})
if i < len(sent) - 4:
word4 = sent[i + 4][0]
postag4 = sent[i + 4][1]
ent4 = sent[i + 4][2]
features.update({
'+4:word.lower()': str(word4).lower(),
'+4:word.istitle()': word4.istitle(),
'+4:word.isupper()': word4.isupper(),
'+4:postag': postag4,
'+4:lemma': lemmatiser(word4)[0],
'+4:ent': ent4,
})
if i < len(sent) - 3:
word3 = sent[i + 3][0]
postag3 = sent[i + 3][1]
ent3 = sent[i + 3][2]
features.update({
'+3:word.lower()': str(word3).lower(),
'+3:word.istitle()': word3.istitle(),
'+3:word.isupper()': word3.isupper(),
'+3:postag': postag3,
'+3:lemma': lemmatiser(word3)[0],
'+3:ent': ent3,
})
if i < len(sent) - 2:
word2 = sent[i + 2][0]
postag2 = sent[i + 2][1]
ent2 = sent[i + 2][2]
features.update({
'+2:word.lower()': str(word2).lower(),
'+2:word.istitle()': word2.istitle(),
'+2:word.isupper()': word2.isupper(),
'+2:postag': postag2,
'+2:lemma': lemmatiser(word2)[0],
'+2:ent': ent2,
})
if i < len(sent) - 1:
word1 = sent[i + 1][0]
postag1 = sent[i + 1][1]
ent1 = sent[i + 1][2]
features.update({
'+1:word.lower()': str(word1).lower(),
'+1:word.istitle()': word1.istitle(),
'+1:word.isupper()': word1.isupper(),
'+1:postag': postag1,
'+1:ent': ent1,
'+1:lemma': lemmatiser(word1)[0],
'+1:punt': word1 in string.punctuation,
})
else:
features['EOS'] = True
return features
def _get_sentence_features(sent):
sent = _chunk(sent)
return [_get_word_features(sent, i) for i in range(len(sent))]
def _get_labels(sent):
return [label for label in sent]
def _chunk(sen):
nltk_download(['averaged_perceptron_tagger', 'maxent_ne_chunker', 'words', 'punkt'])
from nltk.chunk import tree2conlltags
return tree2conlltags(nltk.ne_chunk(nltk.pos_tag(sen)))