-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path2_representation.py
27 lines (23 loc) · 1.16 KB
/
2_representation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from transformers import BertForSequenceClassification, AutoConfig, AutoTokenizer, BertTokenizer, pipeline
import pandas as pd
print('Imported packages')

# Run the fine-tuned BERTje "representation" classifier over every message in
# the input workbook and save the rows with an added prediction column.

# Locations/names kept as constants so they can be changed in one place.
MODEL_PATH = '/home/igevers/representative_claims_public/code/BERTje_representation'
TOKENIZER_NAME = 'GroNLP/bert-base-dutch-cased'
INPUT_PATH = 'data/final_data_object.xlsx'
OUTPUT_PATH = 'data/final_data_object_rep.csv'
TEXT_COLUMN = 'Message_x'
PRED_COLUMN = 'preds_representation'


def _label_id(label):
    """Return the part of a pipeline label after its last underscore.

    Assumes the model config uses default label names of the form
    'LABEL_<id>' (TODO confirm against the saved config); 'LABEL_1' -> '1'.
    Unlike taking only the final character, this also keeps multi-digit ids
    intact ('LABEL_12' -> '12'). A label without '_' is returned unchanged.
    """
    return label.rsplit('_', 1)[-1]


# Check if CUDA (GPU support) is available, else use CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

print('Initiating models...')
model = BertForSequenceClassification.from_pretrained(MODEL_PATH).to(device)
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_NAME, truncation=True, max_length=512)
# Truncation to 512 tokens is also requested on the pipeline itself, so long
# messages are cut rather than exceeding the model's input length.
clf = pipeline(
    "text-classification",
    model,
    tokenizer=tokenizer,
    max_length=512,
    truncation=True,
    device=0 if device == 'cuda' else -1,  # pipeline expects a GPU index or -1 for CPU
)
print('Initiated model')

print('Loading data...')
# The input is an .xlsx workbook, which read_csv cannot parse. Read it as
# Excel; if pandas cannot recognise an Excel format (e.g. the file is really
# plain CSV despite its extension) it raises ValueError, so fall back to CSV.
try:
    input_data = pd.read_excel(INPUT_PATH)
except ValueError:
    input_data = pd.read_csv(INPUT_PATH)
# Force every message to str so the pipeline always receives text; missing
# values become the literal string 'nan'.
input_data[TEXT_COLUMN] = input_data[TEXT_COLUMN].astype(str)
print('Loaded data')

print('Making predictions...')
preds = clf(input_data[TEXT_COLUMN].tolist())
# One prediction dict per message, in input order; store only the label id.
input_data[PRED_COLUMN] = [_label_id(p['label']) for p in preds]
print('Made predictions')

# NOTE: the DataFrame index is written as the first column, as before.
input_data.to_csv(OUTPUT_PATH)
print('Saved data')