Generative AI Model for Text Summarization

The script below fine-tunes the pretrained BART model (facebook/bart-large-cnn) for abstractive text summarization using the Hugging Face Trainer API, then generates a summary for a sample input.
import torch
from transformers import BartForConditionalGeneration, BartTokenizer, Trainer, TrainingArguments
# Define your dataset and data collator (sketched below; the details depend on your data format)
# Load the BART model and tokenizer
model_name = "facebook/bart-large-cnn" # You can choose a different model
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
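# A minimal sketch of a dataset class, assuming you hold parallel lists of source
# documents and reference summaries. The class name and fields are illustrative
# placeholders, not part of the Transformers API; adapt them to your data.
from torch.utils.data import Dataset

class SummarizationDataset(Dataset):
    def __init__(self, texts, summaries, tokenizer, max_input_len=1024, max_target_len=150):
        self.texts = texts
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize one (document, summary) pair; padding is left to the collator
        inputs = self.tokenizer(self.texts[idx], max_length=self.max_input_len, truncation=True)
        targets = self.tokenizer(self.summaries[idx], max_length=self.max_target_len, truncation=True)
        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "labels": targets["input_ids"],
        }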
# Set up training arguments
training_args = TrainingArguments(
    output_dir="./output",
    num_train_epochs=3,               # Adjust as needed
    per_device_train_batch_size=4,    # Adjust to fit your GPU memory
    save_steps=10_000,                # Save a checkpoint every 10,000 steps
    logging_steps=100,                # Log training progress every 100 steps
    evaluation_strategy="steps",      # Run evaluation on a step interval
    eval_steps=1000,                  # Evaluate every 1,000 steps
    save_total_limit=5,               # Keep at most 5 checkpoints on disk
)
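# For batching, Transformers provides a ready-made seq2seq collator that pads
# inputs and labels dynamically; a reasonable default for BART fine-tuning:
from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# A sketch of a ROUGE-based metric function, assuming the `evaluate` and
# `rouge_score` packages are installed. Note that the plain Trainer passes
# logits to compute_metrics, so we take an argmax as a rough proxy; for true
# generation-based ROUGE you would switch to Seq2SeqTrainer with
# predict_with_generate=True.
import numpy as np
import evaluate

rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # Logits -> token ids (teacher-forced argmax)
    predictions = np.argmax(predictions, axis=-1)
    # Restore the pad token where labels were masked with -100 before decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return rouge.compute(predictions=decoded_preds, references=decoded_labels)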
# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,      # e.g., the seq2seq collator defined above
    train_dataset=None,               # Provide your training dataset (e.g., a SummarizationDataset)
    eval_dataset=None,                # Required because evaluation_strategy="steps" is set above
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,  # e.g., the ROUGE sketch defined above
)
# Train the model
trainer.train()
# Save the fine-tuned model (written to the output_dir set above, ./output)
trainer.save_model()
# Load the fine-tuned model back for inference
model = BartForConditionalGeneration.from_pretrained("./output")
# Testing (Generating summaries)
input_text = "Your input text here..." # Provide your input text
input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=1024, truncation=True)
# Generate the summary (beam search typically yields better summaries than greedy decoding)
output_ids = model.generate(input_ids, max_length=150, num_beams=4, early_stopping=True, num_return_sequences=1)
# Decode and print the generated summary
generated_summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print("Generated Summary:", generated_summary)