Source code for canary.argument_pipeline

"""The argument_pipeline package contains the functionality required to analyse a unstructured document
of natural language.
"""
import glob
import itertools
import os
import zipfile
from pathlib import Path

import joblib
import requests

from ._utils import models_available_on_disk, get_models_not_on_disk, get_downloadable_assets_from_github, \
    log_training_data

# Names exported by ``from canary.argument_pipeline import *``; every entry is
# a function defined in this module.
__all__ = [
    "load_model",
    "download_model",
    "analyse",
    "analyse_file",
]


def download_model(model: str, download_to: str = None, overwrite=False):
    """Downloads the pretrained Canary models from GitHub.

    Parameters
    ----------
    model: str or list of str
        The model ID to download, a list of model IDs, or ``"all"`` to fetch
        every asset attached to the release.
    download_to: str or pathlib.Path, optional
        Where to download the model to. Defaults to
        ``CANARY_MODEL_STORAGE_LOCATION``.
    overwrite: bool, default=False
        Should Canary overwrite existing models if they are already present.

    Raises
    ------
    ValueError
        If ``model`` is neither a string nor a list, or if a list contains
        anything other than strings.
    """
    # Validate before any import, directory creation or network traffic.
    if not isinstance(model, (str, list)):
        raise ValueError("Model value should be a string or a list")
    if isinstance(model, list) and not all(isinstance(j, str) for j in model):
        raise ValueError("All items in the list should be strings.")

    import shutil

    import canary
    from canary.utils import CANARY_MODEL_STORAGE_LOCATION

    # Make sure the canary local directory(s) exist beforehand
    os.makedirs(CANARY_MODEL_STORAGE_LOCATION, exist_ok=True)

    # Always work with a Path so the ``/`` joins below also work when the
    # caller passes a plain string.
    if download_to is None:
        download_to = Path(CANARY_MODEL_STORAGE_LOCATION)
    else:
        download_to = Path(download_to)
        download_to.mkdir(parents=True, exist_ok=True)

    def unzip_model(model_zip):
        with zipfile.ZipFile(model_zip, "r") as zf:
            zf.extractall(download_to)
        canary.utils.logger.info("models extracted.")
        os.remove(model_zip)

    def download_asset(asset_dict):
        if 'url' not in asset_dict:
            return

        # The context manager releases the connection once the body is saved.
        with requests.get(asset_dict['url'], headers={
            "Accept": "application/octet-stream"
        }, stream=True) as asset_res:
            if asset_res.status_code != 200:
                canary.utils.logger.error("There was an error downloading the asset.")
                return

            file = download_to / asset_dict['name']
            with open(file, "wb") as f:
                # Stream to disk instead of holding the whole model in memory.
                shutil.copyfileobj(asset_res.raw, f)

        canary.utils.logger.info("downloaded model")
        if file.suffix == ".zip":
            unzip_model(file)
        if file.suffix == ".joblib":
            canary.utils.logger.info(f"{file.name} downloaded to {file.parent}")

    # ``wanted`` is None for "all", otherwise the set of IDs still to fetch.
    wanted = None
    if model != "all":
        requested = [model] if isinstance(model, str) else list(model)
        if overwrite is False:
            # Only the requested models count as "already downloaded"; an
            # unrelated .joblib in the directory must not cancel the download.
            missing = []
            for requested_id in requested:
                candidate = download_to / f"{requested_id}.joblib"
                if candidate.is_file():
                    canary.utils.logger.warn(f"{candidate.stem} already present: {candidate}")
                else:
                    missing.append(requested_id)
            if not missing:
                return
            requested = missing
        wanted = set(requested)

    github_releases = get_downloadable_assets_from_github()
    models_on_disk = models_available_on_disk()

    if github_releases is None:
        canary.utils.logger.error(f"There was an issue getting the models")
        return

    # parse JSON response
    if 'assets' in github_releases:
        assets = github_releases['assets']
        if not assets:
            canary.utils.logger.info("No assets to download.")
            return

        for asset in assets:
            # The asset's model ID is its file name up to the first dot.
            name = asset['name'].split(".")[0]
            if name in models_on_disk and overwrite is False:
                canary.utils.logger.warn(f"{name} already present.")
                continue
            if wanted is None or name in wanted:
                download_asset(asset)

    # NOTE(review): this tests for the key's presence, not its value; confirm
    # that the helper only includes the key for actual pre-releases.
    if 'prerelease' in github_releases:
        canary.utils.logger.info("This has been marked as a pre-release.")
def analyse_file(file, min_link_confidence=0.8, min_support_confidence=0.8, min_attack_confidence=0.8):
    """Read a document from disk and run `analyse` on its contents.

    Parameters
    ----------
    file: str
        The absolute file path
    min_link_confidence: float, default=0.8
        The minimum confidence needed for two arguments to be considered "linked"
    min_support_confidence: float, default=0.8
        The minimum confidence needed to be classified as a support relation
    min_attack_confidence: float, default=0.8
        The minimum confidence needed to be classified as an attack relation

    Returns
    -------
    dict
        the SADFace document.

    Raises
    ------
    TypeError
        If ``file`` does not point at an existing regular file.

    Examples
    --------
    >>> from canary import analyse_file
    >>> document = "/Users/my_user/doc.txt"
    >>> analysis = analyse_file(document)
    >>> analysis
    {
        "metadata": {...},
        "resources": {...},
        "nodes": {...},
        "edges": {...}
    }

    Notes
    -----
    Refer to https://github.com/ARG-ENU/SADFace
    """
    path_is_file = os.path.isfile(file)
    if not path_is_file:
        raise TypeError("file argument should be a valid file")

    # The thresholds are forwarded untouched to `analyse`.
    thresholds = {
        "min_link_confidence": min_link_confidence,
        "min_support_confidence": min_support_confidence,
        "min_attack_confidence": min_attack_confidence,
    }

    with open(file, "r", encoding='utf-8') as handle:
        text = handle.read()
        return analyse(text, **thresholds)
def _build_link_features(arg1, arg2, sentences):
    """Build the feature dictionary for one ordered pair of components."""
    same_sentence = any(
        arg1['component'] in s and arg2['component'] in s for s in sentences
    )

    # NOTE(review): the two ``*_n_following_components`` features read
    # ``n_following_comp_tokens`` while ``n_para_components`` reads
    # ``n_following_components``. Left unchanged because the trained models
    # presumably expect exactly these values; confirm against training code.
    return {
        "source_before_target": arg1['component_position'] > arg2['component_position'],
        "arg1_component": arg1["component"],
        "arg2_component": arg2["component"],
        "arg1_covering_sentence": arg1["cover_sentence"],
        "arg2_covering_sentence": arg2["cover_sentence"],
        "arg1_n_preceding_components": arg1['n_preceding_components'],
        "arg1_n_following_components": arg1['n_following_comp_tokens'],
        "arg2_n_preceding_components": arg2['n_preceding_components'],
        "arg2_n_following_components": arg2['n_following_comp_tokens'],
        "arg1_first_in_paragraph": arg1["first_in_paragraph"],
        "arg2_first_in_paragraph": arg2["first_in_paragraph"],
        "arg2_last_in_paragraph": arg2['last_in_paragraph'],
        "arg1_last_in_paragraph": arg1['last_in_paragraph'],
        "arg1_in_intro": arg1["is_in_intro"],
        "arg2_in_intro": arg2["is_in_intro"],
        "arg1_in_conclusion": arg1["is_in_conclusion"],
        "arg2_in_conclusion": arg2["is_in_conclusion"],
        "arg1_type": arg1["type"],
        "arg2_type": arg2["type"],
        "arg1_and_arg2_in_same_sentence": same_sentence,
        "n_para_components": arg1['n_following_components'] + arg2['n_preceding_components'],
        "arg1_position": arg1['component_position'],
        "arg2_position": arg2['component_position'],
        'arg1_indicator_type_follows_component': arg1['indicator_type_follows_component'],
        'arg2_indicator_type_follows_component': arg2['indicator_type_follows_component'],
        'arg1_indicator_type_precedes_component': arg1['indicator_type_precedes_component'],
        'arg2_indicator_type_precedes_component': arg2['indicator_type_precedes_component']
    }


def analyse(document: str, min_link_confidence=0.8, min_support_confidence=0.8, min_attack_confidence=0.8):
    r"""Analyses a document.

    The document is segmented into argument components, each component is
    typed, every ordered pair of components is tested for a link, and linked
    pairs are labelled as support or attack relations. The result is written
    into a SADFace document.

    Parameters
    ----------
    document: str
        The document text that is being analysed
    min_link_confidence: float, default=0.8
        The minimum confidence needed for two arguments to be considered "linked"
    min_support_confidence: float, default=0.8
        The minimum confidence needed to be classified as a support relation.
        Accepted for API compatibility; this function does not currently
        apply it.
    min_attack_confidence: float, default=0.8
        The minimum confidence needed to be classified as an attack relation.
        Accepted for API compatibility; this function does not currently
        apply it.

    Returns
    -------
    dict
        the SADFace document, or ``None`` when the segmenter finds no
        argument components.

    Raises
    ------
    TypeError
        If ``document`` is not a string, or if the component, link or
        structure predictor cannot be loaded.
    ValueError
        If the argument segmenter cannot be loaded.

    Examples
    --------
    >>> from canary import analyse
    >>> document_text = "..."
    >>> analysis = analyse(document_text, min_link_confidence=0.65)
    >>> analysis
    {
        "metadata": {...},
        "resources": {...},
        "nodes": {...},
        "edges": {...}
    }

    Notes
    -----
    Refer to https://github.com/ARG-ENU/SADFace
    """
    # Reject bad input before paying for the imports and model loading.
    if not isinstance(document, str):
        raise TypeError("The inputted document should be a string.")

    from ..argument_pipeline.argument_segmentation import ArgumentSegmenter
    from ..argument_pipeline.component_prediction import ArgumentComponent
    from ..utils import logger
    from .. import __version__
    from ..corpora import _essay_corpus

    logger.debug(document)

    # load segmenter
    logger.debug("Loading Argument Segmenter")
    segmenter: ArgumentSegmenter = load_model('arg_segmenter')
    if segmenter is None:
        logger.error("Failed to load segmenter")
        raise ValueError("Could not load segmenter.")

    components = segmenter.get_components_from_document(document)
    if len(components) == 0:
        logger.warn("Didn't find any evidence of argumentation")
        return None

    n_claims, n_major_claims, n_premises = 0, 0, 0

    logger.debug("Loading component predictor.")
    component_predictor: ArgumentComponent = load_model('argument_component')
    if component_predictor is None:
        raise TypeError("Could not load argument component predictor")

    for component in components:
        component['type'] = component_predictor.predict(component)
        if component['type'] == 'Claim':
            n_claims += 1
        elif component['type'] == "MajorClaim":
            n_major_claims += 1
        elif component['type'] == "Premise":
            n_premises += 1

    from canary.argument_pipeline.link_predictor import LinkPredictor
    link_predictor: LinkPredictor = load_model('link_predictor')
    if link_predictor is None:
        raise TypeError("Could not load link predictor")

    # Every ordered pair of distinct components, with each permutation
    # reversed so the pair order matches what the predictors were given before.
    component_pairs = [
        (second, first)
        for first, second in itertools.permutations(components, 2)
        if first != second
    ]
    logger.debug(f"{len(component_pairs)} possible combinations.")

    sentences = _essay_corpus.tokenize_essay_sentences(document)

    candidate_relations = []
    for arg1, arg2 in component_pairs:
        features = _build_link_features(arg1, arg2, sentences)
        features["args_linked"] = link_predictor.predict(features) == "Linked"
        candidate_relations.append(features)
    logger.debug("Done")

    linked_relations = [
        pair for pair in candidate_relations
        if pair["args_linked"] is True
        and link_predictor.predict(pair, probability=True)["Linked"] >= min_link_confidence
    ]

    logger.debug(
        f" {len(linked_relations)} / {len(candidate_relations)} identified as being linked")

    # Find attack / support relations
    if len(linked_relations) > 0:
        from canary.argument_pipeline.structure_prediction import StructurePredictor
        sp: StructurePredictor = load_model('structure_predictor')
        # Same guard as the other models: load_model returns None on failure.
        if sp is None:
            raise TypeError("Could not load structure predictor")

        for relation in linked_relations:
            relation['scheme'] = sp.predict(relation)

    # Built unconditionally so both lists exist (empty) when nothing is linked.
    support_relations = [pair for pair in linked_relations if pair["scheme"] == 'supports']
    attacks_relations = [pair for pair in linked_relations if pair["scheme"] == 'attacks']

    logger.debug(f"Number of attack relations: {len(attacks_relations)}")
    logger.debug(f"Number of support relations: {len(support_relations)}")

    # Create a sadface document for what we have found
    from sadface import sadface
    sadface.initialise()
    sadface.set_title("Canary Analysis")

    # Create atoms
    for c in components:
        atom = sadface.add_atom(c['component'])
        sadface.add_atom_metadata(atom['id'], 'canary', 'type', c['type'])

    # Add edges
    for relation in attacks_relations:
        arg1_id = sadface.get_atom_id(relation['arg1_component'])
        arg2_id = sadface.get_atom_id(relation['arg2_component'])
        sadface.add_conflict(arg_id=arg1_id, conflict_id=arg2_id)

    for relation in support_relations:
        arg1_id = sadface.get_atom_id(relation['arg1_component'])
        arg2_id = sadface.get_atom_id(relation['arg2_component'])
        sadface.add_support(con_id=arg1_id, prem_id=[arg2_id])

    # add some nice metadata
    sadface.add_global_metadata('canary', 'number_of_components', len(components))
    sadface.add_global_metadata('canary', 'number_of_attack_relations', len(attacks_relations))
    sadface.add_global_metadata('canary', 'number_of_support_relations', len(support_relations))
    sadface.add_global_metadata('canary', 'number_of_linked_relations', len(linked_relations))
    sadface.add_global_metadata('canary', 'number_of_premises', n_premises)
    sadface.add_global_metadata('canary', 'number_of_claims', n_claims)
    sadface.add_global_metadata('canary', 'number_of_major_claims', n_major_claims)
    sadface.add_global_metadata('canary', 'version', __version__)

    return sadface.get_document()
def load_model(model_id: str, model_dir=None, download_if_missing=False):
    """Load a trained Canary model from disk.

    Parameters
    ----------
    model_id: str
        The ID of the model to load, with or without the ``.joblib`` suffix.
    model_dir: str
        Where the model should be loaded from. Defaults to
        ``CANARY_MODEL_STORAGE_LOCATION``.
    download_if_missing: bool
        Should Canary attempt to download the model if it is not present on disk?

    Returns
    -------
    Model
        The Canary model, or ``None`` if it could not be found or loaded.

    Examples
    --------
    >>> import canary
    >>> component_detector = canary.load_model("argument_component")
    >>> print(component_detector.__class__.__name__)
    """
    import canary
    from canary.utils import CANARY_MODEL_STORAGE_LOCATION

    # Keep both forms: the bare ID is what download_model expects, the file
    # name is what lives on disk. Using the suffix (not a substring test)
    # ensures the bare ID is always known, even when the caller passed
    # "name.joblib".
    suffix = ".joblib"
    if model_id.endswith(suffix):
        bare_id = model_id[:-len(suffix)]
        file_name = model_id
    else:
        bare_id = model_id
        file_name = model_id + suffix

    if model_dir:
        absolute_model_path = Path(model_dir) / file_name
        if not os.path.isfile(absolute_model_path):
            canary.utils.logger.warn("There does not appear to be a model here. This may fail.")
    else:
        absolute_model_path = Path(CANARY_MODEL_STORAGE_LOCATION) / file_name

    if os.path.isfile(absolute_model_path):
        canary.utils.logger.debug(f"Loading {file_name} from {absolute_model_path}")
        model = None
        try:
            # NOTE: joblib.load unpickles the file, so only load models from
            # a source you trust.
            model = joblib.load(absolute_model_path)
        except ModuleNotFoundError:
            canary.utils.logger.error("Could not load model. It may be out of date. "
                                      "Download the newer version or you can train the model yourself.")
        return model

    if download_if_missing is True:
        # Download into the directory we are about to read from; otherwise
        # the retry below could never find the file when model_dir is set.
        if model_dir:
            download_model(bare_id, download_to=Path(model_dir))
        else:
            download_model(bare_id)
        # Retry exactly once, with downloading disabled to avoid recursion.
        return load_model(file_name, model_dir, False)

    canary.utils.logger.error(
        f"Did not manage to load the model specified. Available models on disk are: {models_available_on_disk()}.")

    models_not_on_disk = get_models_not_on_disk()
    if len(models_not_on_disk) > 0:
        canary.utils.logger.error(f"Models available via download are: {models_not_on_disk}.")
    return None