class ZeroShotClassification:
    """Zero-shot topic classifier backed by a Hugging Face transformers pipeline.

    The pipeline is created in ``__enter__`` and stored on ``self.classifier``,
    so ``predict`` relies on ``__enter__`` having run first.
    NOTE(review): presumably the ``stub`` runtime invokes ``__enter__`` once per
    container before serving ``predict`` calls — confirm.
    """

    def __enter__(self):
        """Build the BART-MNLI zero-shot classification pipeline and store it."""
        # Imported inside the method so that merely importing this module
        # does not require transformers to be installed.
        from transformers import pipeline

        self.classifier = pipeline(
            "zero-shot-classification", model="facebook/bart-large-mnli"
        )

    @stub.function(cpu=4, retries=3)
    def predict(self, tweet: str, threshold: float = 0.5):
        """Classify ``tweet`` against the candidate labels in ``topics``.

        Args:
            tweet: Text to classify.
            threshold: A label is kept only if its score is strictly greater
                than this value. Defaults to 0.5, the previous hard-coded
                cut-off.

        Returns:
            ``{'text': tweet, 'labels': [...]}``, where ``labels`` lists the
            labels above ``threshold`` in the order the pipeline returned them.
        """
        # NOTE(review): `topics` is not defined in this class; presumably a
        # module-level sequence of candidate label strings — confirm.
        result = self.classifier(tweet, topics, multi_label=False)
        # Per the transformers docs, with multi_label=False the scores are
        # normalised across all candidate labels (they sum to 1). So at most
        # one label can exceed the default 0.5 threshold, and possibly none.
        labels = [
            label
            for label, score in zip(result['labels'], result['scores'])
            if score > threshold
        ]
        return {'text': tweet, 'labels': labels}