Insights

get_annotation_labels(examples, case_sensitive=False)

Constructs a map from each annotation text in the list of examples to each label that annotation has, and references all examples associated with that label.

Parameters:

Name            Type           Description                                 Default
examples        List[Example]  Input examples                              required
case_sensitive  bool           Consider case of text for each annotation   False

Returns:

Type                        Description
Dict[str, Dict[str, list]]  Annotation map

Source code in recon/insights.py
def get_annotation_labels(
    examples: List[Example], case_sensitive: bool = False
) -> Dict[str, Dict[str, list]]:
    """Constructs a map of each annotation in the list of examples
    to each label that annotation has and references all examples
    associated with that label.

    Args:
        examples (List[Example]): Input examples
        case_sensitive (bool, optional): Consider case of text for each annotation

    Returns:
        Dict[str, Dict[str, list]]: Annotation map
    """
    annotation_labels_map: Dict[str, Dict[str, list]] = defaultdict(
        lambda: defaultdict(list)
    )
    for e in examples:
        for s in e.spans:
            text = s.text if case_sensitive else s.text.lower()
            annotation_labels_map[text][s.label].append(e)

    return annotation_labels_map
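
A minimal usage sketch. It assumes examples is a List[Example] you have already loaded elsewhere (the variable name is a placeholder), and uses the returned map to surface annotation texts that carry more than one label:

from recon.insights import get_annotation_labels

# `examples` is assumed to be a List[Example] loaded elsewhere
annotation_labels = get_annotation_labels(examples, case_sensitive=False)

# Texts mapped to more than one label are candidates for inconsistent annotation
for text, labels_to_examples in annotation_labels.items():
    if len(labels_to_examples) > 1:
        counts = {label: len(exs) for label, exs in labels_to_examples.items()}
        print(text, counts)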

get_ents_by_label(data, case_sensitive=False)

Get a dictionary of unique text spans by label for your data

We want to return a dictionary that maps labels to AnnotationCount objects, where each AnnotationCount contains the annotation text, the total number of times it is mentioned (e.g. what entity_coverage does), and the examples it appears in.

Parameters:

Name            Type           Description                                 Default
data            List[Example]  List of examples                            required
case_sensitive  bool           Consider case of text for each annotation   False

Returns:

Type                                              Description
DefaultDict[str, DefaultDict[str, Set[Example]]]  DefaultDict mapping each label to its unique annotated span texts and the set of examples each span appears in

Source code in recon/insights.py
def get_ents_by_label(
    data: List[Example], case_sensitive: bool = False
) -> DefaultDict[str, DefaultDict[str, Set[Example]]]:
    """Get a dictionary of unique text spans by label for your data

    We want to return a dictionary that maps labels to AnnotationCount
    objects where each AnnotationCount contains the text of the annotation text,
    the total number of times it's mentioned (e.g. what entity_coverage does)
    but also the examples it is in.

    Args:
        data (List[Example]): List of examples
        case_sensitive (bool, optional): Consider case of text for each annotation

    Returns:
        DefaultDict[str, DefaultDict[str, Set[Example]]]: DefaultDict mapping
        label to sorted list of the unique spans annotated for that label.
    """
    annotations: DefaultDict[str, DefaultDict[str, Set[Example]]] = defaultdict(
        lambda: defaultdict(set)
    )
    for example in data:
        for s in example.spans:
            span_text = s.text if case_sensitive else s.text.lower()
            annotations[s.label][span_text].add(example)
    return annotations
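
A short sketch of consuming the returned mapping, again assuming examples is a List[Example] loaded elsewhere:

from recon.insights import get_ents_by_label

ents_by_label = get_ents_by_label(examples, case_sensitive=False)

# For each label, list its unique annotated texts and how many examples mention each
for label, texts_to_examples in ents_by_label.items():
    for span_text, exs in sorted(texts_to_examples.items()):
        print(label, span_text, len(exs))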

get_hardest_examples(recognizer, examples, score_count=True, normalize_scores=True)

Get the hardest examples for a recognizer to predict on, sorted by difficulty, with the goal of quickly identifying the biggest holes in a model / annotated data.

Parameters:

Name              Type              Description                                                       Default
recognizer        EntityRecognizer  EntityRecognizer to test predictions for                          required
examples          List[Example]     Set of input examples                                             required
score_count       bool              Adjust score by total number of errors                            True
normalize_scores  bool              Scale scores to range [0, 1] adjusted by total number of errors   True

Returns:

Type               Description
List[ExampleDiff]  ExampleDiff objects sorted by difficulty (hardest first)

Source code in recon/insights.py
def get_hardest_examples(
    recognizer: EntityRecognizer,
    examples: List[Example],
    score_count: bool = True,
    normalize_scores: bool = True,
) -> List[ExampleDiff]:
    """Get hardest examples for a recognizer to predict on and sort by
    difficulty with the goal of quickly identifying the biggest holes
    in a model / annotated data.

    Args:
        recognizer (EntityRecognizer): EntityRecognizer to test predictions for
        examples (List[Example]): Set of input examples
        score_count (bool): Adjust score by total number of errors
        normalize_scores (bool): Scale scores to range [0, 1] adjusted by
            total number of errors

    Returns:
        List[HardestExample]: HardestExamples sorted by difficulty (hardest first)
    """
    preds = recognizer.predict((e.text for e in examples))

    max_count = 0
    hes = []
    for pred, ref in zip(preds, examples):
        scorer = PRFScore()
        scorer.score_set(
            set([(s.start, s.end, s.label) for s in pred.spans]),
            set([(s.start, s.end, s.label) for s in ref.spans]),
        )
        total_errors = scorer.fp + scorer.fn
        score = scorer.fscore if (pred.spans and ref.spans) else 1.0

        if total_errors > max_count:
            max_count = total_errors

        he = ExampleDiff(
            reference=ref, prediction=pred, count=total_errors, score=score
        )
        hes.append(he)

    if score_count:
        for he in hes:
            he.score -= he.count / max_count
        if normalize_scores:
            scores = np.asarray([he.score for he in hes])
            scores = (scores - scores.min()) / np.ptp(scores)
            for i, he in enumerate(hes):
                he.score = scores[i]

    sorted_hes = sorted(hes, key=lambda he: (he.score, he.count))
    return sorted_hes
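
A rough usage sketch. It assumes you have an EntityRecognizer implementation; recon ships a spaCy-backed recognizer, shown here as SpacyEntityRecognizer, though the exact import path and the spaCy model name are assumptions for illustration:

import spacy
from recon.insights import get_hardest_examples
from recon.recognizer import SpacyEntityRecognizer  # import path is an assumption

nlp = spacy.load("en_core_web_sm")  # placeholder spaCy pipeline
recognizer = SpacyEntityRecognizer(nlp)

# `examples` is assumed to be a List[Example] of annotated data
hardest = get_hardest_examples(recognizer, examples)
for he in hardest[:10]:
    # lowest scores come first, i.e. the examples the model struggles with most
    print(he.score, he.count, he.reference.text)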

get_label_disparities(data, label1, label2, case_sensitive=False)

Identify annotated spans that have different labels in different examples

Parameters:

Name            Type           Description                                 Default
data            List[Example]  Input List of examples                      required
label1          str            First label to compare                      required
label2          str            Second label to compare                     required
case_sensitive  bool           Consider case of text for each annotation   False

Returns:

Type                      Description
Dict[str, List[Example]]  Dict mapping each unique text span annotated as both label1 and label2 to the examples containing it

Source code in recon/insights.py
def get_label_disparities(
    data: List[Example], label1: str, label2: str, case_sensitive: bool = False
) -> Dict[str, List[Example]]:
    """Identify annotated spans that have different labels in different examples

    Args:
        data (List[Example]): Input List of examples
        label1 (str): First label to compare
        label2 (str): Second label to compare
        case_sensitive (bool, optional): Consider case of text for each annotation

    Returns:
        Dict[str, List[Example]]: Set of all unique text spans that
            overlap between label1 and label2
    """
    annotations = get_ents_by_label(data, case_sensitive=case_sensitive)
    overlap = set(annotations[label1]).intersection(set(annotations[label2]))

    output = defaultdict(list)
    for ann in overlap:
        if ann in annotations[label1]:
            output[ann] += annotations[label1][ann]
        if ann in annotations[label2]:
            output[ann] += annotations[label2][ann]
    return output
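
A minimal sketch, assuming data is a loaded List[Example] and using placeholder label names:

from recon.insights import get_label_disparities

# "PERSON" and "ORG" are placeholder labels; use labels from your own dataset
disparities = get_label_disparities(data, "PERSON", "ORG", case_sensitive=False)
for span_text, containing_examples in disparities.items():
    print(span_text, len(containing_examples))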

top_label_disparities(data, case_sensitive=False, dedupe=False)

Identify annotated spans that have different labels in different examples for all label pairs in data.

Parameters:

Name            Type           Description                                                                                                    Default
data            List[Example]  Input List of examples                                                                                         required
case_sensitive  bool           Consider case of text for each annotation                                                                      False
dedupe          bool           Whether to deduplicate for table view vs confusion matrix. False by default for easy confusion matrix display.  False

Returns:

Type                  Description
List[LabelDisparity]  LabelDisparity objects for each label pair combination, sorted by the number of disparities between them

Source code in recon/insights.py
def top_label_disparities(
    data: List[Example], case_sensitive: bool = False, dedupe: bool = False
) -> List[LabelDisparity]:
    """Identify annotated spans that have different labels
    in different examples for all label pairs in data.

    Args:
        data (List[Example]): Input List of examples
        case_sensitive (bool, optional): Consider case of text for each annotation
        dedupe (bool, optional): Whether to deduplicate for table
            view vs confusion matrix. False by default for easy
            confusion matrix display.

    Returns:
        List[LabelDisparity]: List of LabelDisparity objects for each
            label pair combination sorted by the number of disparities between them.
    """
    annotations = get_ents_by_label(data, case_sensitive=case_sensitive)
    label_disparities = {}
    for label1 in annotations.keys():
        for label2 in annotations.keys():
            if label1 != label2:
                intersection = set(annotations[label1]).intersection(
                    set(annotations[label2])
                )
                n_disparities = len(intersection)
                if n_disparities > 0:
                    if dedupe:
                        input_hash = "||".join(sorted([label1, label2]))
                    else:
                        input_hash = "||".join([label1, label2])

                    label_disparities[input_hash] = LabelDisparity(
                        label1=label1, label2=label2, count=n_disparities
                    )

    return sorted(label_disparities.values(), key=lambda ld: ld.count, reverse=True)
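
A minimal sketch, assuming data is a loaded List[Example]:

from recon.insights import top_label_disparities

# dedupe=True collapses (label1, label2) and (label2, label1) into one row for a table view
for ld in top_label_disparities(data, dedupe=True):
    print(ld.label1, ld.label2, ld.count)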

top_prediction_errors(recognizer, data, labels=[], exclude_fp=False, exclude_fn=False, verbose=False, return_examples=False)

Get a sorted list of examples your model is worst at predicting.

Parameters:

Name Type Description Default
recognizer EntityRecognizer

An instance of EntityRecognizer

required
data List[Example]

List of annotated Examples

required
labels List[str]

List of labels to get errors for. Defaults to the labels property of recognizer.

[]
exclude_fp bool

Flag to exclude False Positive errors.

False
exclude_fn bool

Flag to exclude False Negative errors.

False
verbose bool

Show verbose output.

False
return_examples bool

Return Examples that contain the entity label annotation.

False

Returns:

Type                   Description
List[PredictionError]  Prediction errors your model is making, sorted by the spans your model has the most trouble with

Source code in recon/insights.py
def top_prediction_errors(
    recognizer: EntityRecognizer,
    data: List[Example],
    labels: List[str] = [],
    exclude_fp: bool = False,
    exclude_fn: bool = False,
    verbose: bool = False,
    return_examples: bool = False,
) -> List[PredictionError]:
    """Get a sorted list of examples your model is worst at predicting.

    Args:
        recognizer (EntityRecognizer): An instance of EntityRecognizer
        data (List[Example]): List of annotated Examples
        labels (List[str]): List of labels to get errors for.
            Defaults to the labels property of `recognizer`.
        exclude_fp (bool, optional): Flag to exclude False Positive errors.
        exclude_fn (bool, optional): Flag to exclude False Negative errors.
        verbose (bool, optional): Show verbose output.
        return_examples (bool, optional): Return Examples that
            contain the entity label annotation.

    Returns:
        List[PredictionError]: List of Prediction Errors your model is making,
            sorted by the spans your model has the most trouble with.
    """
    labels = labels or recognizer.labels
    texts = (e.text for e in data)
    anns = (e.spans for e in data)
    preds = recognizer.predict(texts)

    errors = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))  # type: ignore
    error_examples: DefaultDict[
        Tuple[str, str, str], List[PredictionErrorExamplePair]
    ] = defaultdict(list)
    n_errors = 0

    for orig_example, pred_example, ann in zip(data, preds, anns):
        pred_error_example_pair = PredictionErrorExamplePair(
            original=orig_example, predicted=pred_example
        )

        cand = set([(s.start, s.end, s.label) for s in pred_example.spans])
        gold = set([(s.start, s.end, s.label) for s in ann])

        fp_diff = cand - gold
        fn_diff = gold - cand

        seen = set()

        if fp_diff and not exclude_fp:
            for fp in fp_diff:
                gold_ent = None
                for ge in gold:
                    if fp[0] == ge[0] and fp[1] == ge[1]:
                        gold_ent = ge
                        break
                if gold_ent:
                    start, end, label = gold_ent
                    text = pred_example.text[start:end]
                    false_label = fp[2]
                    errors[label][text][false_label] += 1
                    error_examples[(text, label, false_label)].append(
                        pred_error_example_pair
                    )
                else:
                    start, end, false_label = fp
                    text = pred_example.text[start:end]
                    errors[NOT_LABELED][text][false_label] += 1
                    error_examples[(text, NOT_LABELED, false_label)].append(
                        pred_error_example_pair
                    )
                n_errors += 1
                seen.add((start, end))

        if fn_diff and not exclude_fn:
            for fn in fn_diff:
                start, end, label = fn
                if (start, end) not in seen:
                    text = pred_example.text[start:end]
                    errors[label][text][NOT_LABELED] += 1
                    error_examples[(text, label, NOT_LABELED)].append(
                        pred_error_example_pair
                    )
                    n_errors += 1

    ranked_errors_map: Dict[Tuple[str, str, str], PredictionError] = {}

    for label, errors_per_label in errors.items():
        for error_text, error_labels in errors_per_label.items():
            for error_label, count in error_labels.items():
                pe_hash = (error_text, label, error_label)
                pe = PredictionError(
                    text=error_text,
                    true_label=label,
                    pred_label=error_label,
                    count=count,
                )
                if return_examples:
                    pe.examples = error_examples[pe_hash]
                ranked_errors_map[pe_hash] = pe

    ranked_errors: List[PredictionError] = sorted(
        list(ranked_errors_map.values()), key=lambda error: error.count, reverse=True
    )
    error_texts = set()
    for re in ranked_errors:
        if re.examples:
            for e in re.examples:
                error_texts.add(e.original.text)

    error_rate = round(len(error_texts) / len(data), 2)
    if verbose:
        error_summary = {
            "N Examples": len(data),
            "N Errors": len(ranked_errors),
            "N Error Examples": len(error_texts),
            "Error Rate": error_rate,
        }
        msg = Printer()
        msg.divider("Error Analysis")
        msg.table(error_summary)

    return ranked_errors
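
A rough usage sketch, under the same assumptions as above (a spaCy-backed recognizer whose import path and model name are placeholders, and data as a loaded List[Example]):

import spacy
from recon.insights import top_prediction_errors
from recon.recognizer import SpacyEntityRecognizer  # import path is an assumption

nlp = spacy.load("en_core_web_sm")  # placeholder spaCy pipeline
recognizer = SpacyEntityRecognizer(nlp)

# `data` is assumed to be a List[Example] of annotated examples
errors = top_prediction_errors(recognizer, data, return_examples=True, verbose=True)
for pe in errors[:10]:
    print(pe.text, pe.true_label, pe.pred_label, pe.count)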