class ZeroShotClassification:
    """Zero-shot topic classifier for short texts such as tweets.

    The Hugging Face pipeline is created in ``__enter__`` rather than in
    ``predict``, so repeated ``predict`` calls on the same instance reuse one
    loaded model.
    """

    def __enter__(self):
        """Load the ``facebook/bart-large-mnli`` zero-shot pipeline.

        Returns:
            ``self``, so the instance also works with ``with ... as clf``.
        """
        # Imported here rather than at module top so the dependency is only
        # needed where the pipeline is actually built.
        from transformers import pipeline

        self.classifier = pipeline(
            "zero-shot-classification", model="facebook/bart-large-mnli"
        )
        return self

    @stub.function(cpu=4, retries=3)
    def predict(self, tweet: str, threshold: float = 0.5, candidate_labels=None):
        """Classify ``tweet`` against a set of candidate topics.

        Args:
            tweet: Text to classify.
            threshold: A label is kept only if its score is strictly greater
                than this value. Defaults to 0.5 (the previously hard-coded
                cut-off).
            candidate_labels: Labels to score ``tweet`` against. Defaults to
                the module-level ``topics`` defined elsewhere in this file.

        Returns:
            A dict ``{'text': tweet, 'labels': [...]}``. ``labels`` lists every
            candidate whose score exceeds ``threshold``, in the order the
            pipeline reported them.
        """
        if candidate_labels is None:
            candidate_labels = topics
        result = self.classifier(tweet, candidate_labels, multi_label=False)
        # NOTE(review): per the transformers docs, multi_label=False normalises
        # scores across the candidate labels (they sum to 1). With the default
        # threshold of 0.5, at most one label can be kept and often none is.
        # Confirm this is intended.
        labels = [
            label
            for label, score in zip(result["labels"], result["scores"])
            if score > threshold
        ]
        return {"text": tweet, "labels": labels}
# Hosted on Deepnote