Commit bd414ae2 authored by Pado's avatar Pado
Browse files

Generate evaluation table as for German (if observed labels are given)

parent 78131ec0
...@@ -33,6 +33,7 @@ import torch ...@@ -33,6 +33,7 @@ import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from sklearn.metrics import f1_score, accuracy_score from sklearn.metrics import f1_score, accuracy_score
from sklearn.metrics import classification_report
from transformers import ( from transformers import (
BertConfig, BertConfig,
...@@ -491,10 +492,53 @@ def print_predictions(args, preds): ...@@ -491,10 +492,53 @@ def print_predictions(args, preds):
processor.get_test_examples(args.data_dir) processor.get_test_examples(args.data_dir)
) )
# observed grade list created
obs_grade = [ex.label for ex in examples]
# suggested grade list created
sugg_grade = ['correct' if pred == 0 else 'incorrect' for pred in preds]
# flag: do observed grades exist?
count=0
# Check if obs_grade contains "NONE" values or is empty
if not obs_grade or all(grade == 'NONE' for grade in obs_grade):
count += 1
else:
# classification report
classification_rep = classification_report(obs_grade, sugg_grade)
report_string = classification_rep
report_lines = report_string.split('\n')
# print(report_lines)
# accuracy line
formatted_accuracy_line = "\t".join(report_lines[5].split())
formatted_acc_line_with_tabs = (formatted_accuracy_line[:formatted_accuracy_line.index('\t',
formatted_accuracy_line.index(
'\t'))] + '\t\t' +
formatted_accuracy_line[
formatted_accuracy_line.index('\t', formatted_accuracy_line.index('\t')):])
# #weighted avg printing
#
wt_avg_line = "\t".join(report_lines[7].split())
new_wt_avg_line = wt_avg_line.replace("\t", " ", 1)
# Join the entire newly formatted list into a single string
formatted_output = "\n".join([
"\t precision \t recall \t f1-score \t support",
"\t".join(report_lines[2].split()),
"\t".join(report_lines[3].split()),
formatted_acc_line_with_tabs,
new_wt_avg_line
])
with open(args.data_dir + "/" + dir_name + "/predictions.txt", "w", encoding="utf8") as writer: with open(args.data_dir + "/" + dir_name + "/predictions.txt", "w", encoding="utf8") as writer:
# print("# examples: " + str(len(examples)))
# print("# labels: " + str(len(labels)))
# print("# preds: " + str(len(preds)))
writer.write( writer.write(
"question\treferenceAnswer\tstudentAnswer\tsuggested grade\tobserved grade\n") "question\treferenceAnswer\tstudentAnswer\tsuggested grade\tobserved grade\n")
...@@ -519,7 +563,17 @@ def print_predictions(args, preds): ...@@ -519,7 +563,17 @@ def print_predictions(args, preds):
+ examples[i].label + examples[i].label
+ "\n" + "\n"
) )
# else: print("Labels don't match! "+str(i)+": "+str(examples[i].label)+" "+str(labels[i]))
if count == 1:
writer.write("\nClassification Report cannot be printed as observed grade column is empty or filled "
"with 'NONE' or 'none' values\n")
else:
# Write the classification report to the file
writer.write(
"\nClassification Report - high Precision for classes correct or incorrect indicates that the class prediction is reliable:\n")
writer.write(formatted_output)
if __name__ == "__main__": if __name__ == "__main__":
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment