🐛 Bug

To Reproduce

Steps to reproduce the behavior:

I followed https://captum.ai/tutorials/Llama2_LLM_Attribution

My code is below; the only difference is that I changed the model_name.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from captum.attr import (
    FeatureAblation,
    ShapleyValues,
    LayerIntegratedGradients,
    LLMAttribution,
    LLMGradientAttribution,
    TextTokenInput,
    TextTemplateInput,
    ProductBaselines,
)
model_name = "meta-llama/Llama-3.2-1B-Instruct"
def load_model(model_name, bnb_config):
    n_gpus = torch.cuda.device_count()
    max_memory = "10000MB"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",  # dispatch the model efficiently across the available resources
        max_memory={i: max_memory for i in range(n_gpus)},
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=True)
    # Needed for the LLaMA tokenizer, which has no pad token by default
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
def create_bnb_config():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return bnb_config
bnb_config = create_bnb_config()
model, tokenizer = load_model(model_name, bnb_config)
model.eval()
def prompt_fn(*examples):
    main_prompt = "Decide if the following movie review enclosed in quotes is Positive or Negative:\n'I really liked the Avengers, it had a captivating plot!'\nReply only Positive or Negative."
    subset = [elem for elem in examples if elem]
    if not subset:
        prompt = main_prompt
    else:
        prefix = "Here are some examples of movie reviews and classification of whether they were Positive or Negative:\n"
        prompt = prefix + " \n".join(subset) + "\n " + main_prompt
    return "[INST] " + prompt + "[/INST]"
input_examples = [
    "'The movie was ok, the actors weren't great' Negative",
    "'I loved it, it was an amazing story!' Positive",
    "'Total waste of time!!' Negative",
    "'Won't recommend' Negative",
]
sv = ShapleyValues(model)
sv_llm_attr = LLMAttribution(sv, tokenizer)
# attr_res = sv_llm_attr.attribute(inp, target=target, num_trials=3)
inp = TextTemplateInput(
    prompt_fn,
    values=input_examples,
)
attr_res = sv_llm_attr.attribute(inp)
attr_res.plot_token_attr(show=True)
```
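One thing I noticed while writing this up: the `[INST] ... [/INST]` wrapper in `prompt_fn` is the Llama 2 chat format, while Llama 3.x instruct models use a different chat template. A hypothetical variant that lets the tokenizer render its own template instead (my adaptation, not part of the tutorial):

```python
# Hypothetical variant of prompt_fn that uses the tokenizer's built-in
# chat template rather than the hard-coded Llama 2 [INST] tags.
def prompt_fn_chat_template(*examples):
    main_prompt = "Decide if the following movie review enclosed in quotes is Positive or Negative:\n'I really liked the Avengers, it had a captivating plot!'\nReply only Positive or Negative."
    subset = [elem for elem in examples if elem]
    if subset:
        prefix = "Here are some examples of movie reviews and classification of whether they were Positive or Negative:\n"
        main_prompt = prefix + " \n".join(subset) + "\n " + main_prompt
    # apply_chat_template wraps the user message in the model's own
    # special tokens (<|start_header_id|> etc. for Llama 3).
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": main_prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
```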
Expected behavior
It should generate 'Positive' or 'Negative', and plot the attribution scores between the prompt's examples and the output.
The actual output
The system gives a hint: "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results. Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation."
And the plotted picture is: [attribution plot screenshot]
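For what it's worth, the same warning can be reproduced with a plain `generate` call once `pad_token` is set to `eos_token`, independent of Captum (a minimal check I would try):

```python
# Minimal check: calling generate without an attention_mask triggers the
# same warning, because pad_token_id == eos_token_id for this tokenizer.
ids = tokenizer("test", return_tensors="pt").input_ids.to(model.device)
out = model.generate(ids, max_new_tokens=5)
```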
Environment
Describe the environment used for Captum
- Captum: 0.7.0
- OS: Google Colab (Ubuntu 22.04.3 LTS)
- How Captum / PyTorch were installed: pip
- Python version: 3.10.12
- CUDA/cuDNN version:
```
cuda-python 12.2.1
cupy-cuda12x 12.2.0
jax-cuda12-pjrt 0.4.33
jax-cuda12-plugin 0.4.33
nvidia-cuda-cupti-cu12 12.6.80
nvidia-cuda-nvcc-cu12 12.6.77
nvidia-cuda-runtime-cu12 12.6.77
```
- GPU models and configuration:
- PyTorch: 2.4.1+cu121
- Transformers: 4.44.2
## Possible problem
Is the attention_mask incorrect when Captum calls model.generate?
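A workaround I would test (my assumption, not verified against Captum internals): set `pad_token_id` explicitly on the generation config so transformers stops falling back to `eos_token_id` with a warning, and/or forward it through `gen_args` if this Captum version supports that parameter on `attribute`:

```python
# Workaround sketch (hypothetical): make pad_token_id explicit so the
# "Setting pad_token_id to eos_token_id" warning no longer fires.
model.generation_config.pad_token_id = tokenizer.eos_token_id

# If LLMAttribution.attribute accepts gen_args in this Captum version,
# generation arguments can also be forwarded to model.generate directly:
attr_res = sv_llm_attr.attribute(inp, gen_args={"pad_token_id": tokenizer.eos_token_id})
```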