| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| This script provides a simple evaluation pipeline for multiple-choice reasoning datasets |
| (e.g., BoolQ, HellaSwag, ARC, OpenBookQA, Winogrande) with different composition strategies. |
| |
| Usage examples: |
| python arrow_phi3_mini.py --strategy base --ds_name arc-challenge |
| python arrow_phi3_mini.py --strategy arrow --ds_name boolq |
| python arrow_phi3_mini.py --strategy gks --ds_name hswag |
| |
| Key features: |
| - Supports three strategies: |
| • "base" → Evaluate the quantized base model directly |
| • "arrow" → Use Arrow modular routing with task-specific adapters |
| • "gks" → Use Arrow + GenKnowSub (subtracting general-domain knowledge) |
| - Loads evaluation datasets from the Hugging Face Hub |
| - Implements a batched evaluation loop that computes per-option likelihoods and selects |
| the answer with the lowest average loss |
| - Reports simple accuracy |
| |
| Implementation details: |
| - The base model is quantized to 4-bit using `BitsAndBytesConfig` (nf4, bf16 compute). |
| - For Arrow and GKS, task-specific adapters are loaded from the Hugging Face Hub: |
| TahaBa/phi3-mini-clustered-flan/ts_expert_i |
| - Task-specific adapters were trained on 10 clusters of FLAN tasks. |
| - The clusters were created using Model-Based Clustering (MBC): |
| 1. Train a LoRA adapter for each individual task. |
| 2. Apply k-means clustering to group tasks based on these adapters. |
| 3. Train a LoRA adapter for each resulting cluster. |
| For more details, see the Arrow paper: https://huggingface.co/papers/2405.11157 |
| |
| - For GKS, general adapters are loaded from: |
| TahaBa/phi3-mini-general-adapters/... |
| - These adapters were trained on English, French, and German Wikipedia data |
| using a causal language modeling objective with (507-token context → 5-token completion) pairs. |
| - This setup encodes general knowledge into the LoRA space, which can then be |
| subtracted from task-specific adapters during inference to isolate and purify them. |
| For more details, see the GenKnowSub paper: https://huggingface.co/papers/2505.10939 |
| |
| - `evaluate_on_multi_choice_batched` handles tokenization, masking context tokens, |
| and computing per-choice log-likelihoods for fair comparison. |
| - Accuracy is printed at the end for the selected dataset. |
| |
| This script is mainly meant for demonstration purposes and lightweight evaluation, |
| not full-scale benchmarking (batch size / max length can be tuned). |
| |
| ======================================================================================= |
| |
| Results (evaluated with microsoft/Phi-3-mini-4k-instruct, 4-bit quantization): |
| |
| | Dataset | Base Acc. | Arrow Acc. | Arrow+GKS Acc. | |
| |--------------|-----------|------------|----------------| |
| | ARC-Challenge| 0.4515 | 0.5418 | 0.5585 | |
| | ARC-Easy | 0.6894 | 0.8404 | 0.8473 | |
| | Winogrande | 0.5769 | 0.6550 | 0.6724 | |
| | BoolQ | 0.8146 | 0.8030 | 0.8247 | |
| | OpenBookQA | 0.43 | 0.448 | 0.472 | |
| | HellaSwag | 0.7318 | 0.7150 | 0.7376 | |
| |
| Observations: |
| - Arrow generally improves over the base model by routing tokens to the most relevant task adapters. |
| - Applying GKS (general knowledge subtraction) consistently gives further gains compared to Arrow and Base. |
| |
| These numbers are not meant as leaderboard results, but as a sanity check |
| to verify that the implementation works as expected and demonstrates |
| the benefits of Arrow and GenKnowSub. |
| """ |
|
|
| import argparse |
| import random |
|
|
| import numpy as np |
| import torch |
| from datasets import load_dataset |
| from sklearn.metrics import accuracy_score |
| from tqdm import tqdm |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
|
| from peft import ArrowConfig, create_arrow_model |
|
|
|
|
# Hugging Face Hub id used to load both the tokenizer and the base model.
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
# Passed to the tokenizer as `model_max_length`.
MODEL_MAX_LEN = 2048
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of this evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so the
            plain ``parse_args()`` call keeps working as before.

    Returns:
        argparse.Namespace: with the attributes ``strategy``
        (``"base"``, ``"arrow"`` or ``"gks"``) and ``ds_name``
        (``"boolq"``, ``"hswag"``, ``"arc-easy"``, ``"arc-challenge"``,
        ``"oqa"`` or ``"wg"``).
    """
    # The script only evaluates models, so the help text says "Evaluation".
    parser = argparse.ArgumentParser(description="Evaluation script with strategy selection")

    parser.add_argument(
        "--strategy",
        type=str,
        choices=["base", "arrow", "gks"],
        default="base",
        help="Evaluation strategy to use: base, arrow, or gks",
    )
    parser.add_argument(
        "--ds_name",
        type=str,
        choices=["boolq", "hswag", "arc-easy", "arc-challenge", "oqa", "wg"],
        default="arc-challenge",
        help="Dataset to use: boolq, hswag, arc-easy, arc-challenge, oqa, wg",
    )

    return parser.parse_args(argv)
|
|
|
|
def read_test_dataset(ds_name):
    """Load the validation split of one of the supported benchmarks from the Hub.

    Args:
        ds_name: Short dataset key: ``"boolq"``, ``"hswag"``,
            ``"arc-challenge"``, ``"arc-easy"``, ``"oqa"`` or ``"wg"``.

    Returns:
        The object returned by ``datasets.load_dataset`` for the
        ``"validation"`` split of the selected dataset.

    Raises:
        ValueError: If ``ds_name`` is not one of the supported keys.
    """
    # Positional arguments for `load_dataset`: repo id and, where needed, config name.
    hub_locations = {
        "boolq": ("google/boolq",),
        "hswag": ("Rowan/hellaswag",),
        "arc-challenge": ("allenai/ai2_arc", "ARC-Challenge"),
        "arc-easy": ("allenai/ai2_arc", "ARC-Easy"),
        "oqa": ("allenai/openbookqa",),
        "wg": ("allenai/winogrande", "winogrande_xl"),
    }

    if ds_name not in hub_locations:
        # Raising a bare string is itself a TypeError in Python 3, so use a
        # real exception type to surface the intended message.
        raise ValueError(f"Dataset {ds_name} is not supported yet.")

    return load_dataset(*hub_locations[ds_name], split="validation", trust_remote_code=True)
|
|
|
|
def extract_input_content(ds_name, row):
    """Return the prompt text of ``row`` for the dataset identified by ``ds_name``.

    BoolQ combines the passage and the question into a single tagged string;
    every other supported dataset reads exactly one field of the row.

    Args:
        ds_name: Short dataset key (``"boolq"``, ``"hswag"``, ``"arc-challenge"``,
            ``"arc-easy"``, ``"oqa"`` or ``"wg"``).
        row: Mapping holding the fields of one dataset example.

    Returns:
        The prompt text, or ``None`` when ``ds_name`` is not a known key.
    """
    if ds_name == "boolq":
        passage = row["passage"]
        question = row["question"]
        return f"[passage]{passage}[question]{question}"

    # Datasets whose prompt is a single field of the row.
    prompt_field = {
        "hswag": "ctx",
        "arc-challenge": "question",
        "arc-easy": "question",
        "oqa": "question_stem",
        "wg": "sentence",
    }.get(ds_name)

    return None if prompt_field is None else row[prompt_field]
|
|
|
|
def create_multi_choice_options(row, ds_name):
    """Build one chat-formatted text per answer candidate of ``row``.

    Each text has the form
    ``<|user|>\\n{content}<|end|>\\n<|assistant|>{choice}<|end|>\\n`` where
    ``content`` comes from ``extract_input_content``.

    Args:
        row: Mapping holding the fields of one dataset example.
        ds_name: Short dataset key (``"boolq"``, ``"hswag"``, ``"arc-challenge"``,
            ``"arc-easy"``, ``"oqa"`` or ``"wg"``).

    Returns:
        list[str]: One formatted text per candidate, in candidate order.

    Raises:
        ValueError: If ``ds_name`` is not a supported key (previously this
            surfaced as an ``UnboundLocalError`` on ``choices``).
    """
    # Exclusive chain: exactly one branch binds `choices`, or we fail loudly.
    if ds_name == "boolq":
        choices = ["true", "false"]
    elif ds_name == "hswag":
        choices = row["endings"]
    elif ds_name in ("arc-challenge", "arc-easy", "oqa"):
        choices = row["choices"]["text"]
    elif ds_name == "wg":
        choices = [row["option1"], row["option2"]]
    else:
        raise ValueError(f"Dataset {ds_name} is not supported yet.")

    content = extract_input_content(ds_name, row)
    return [f"<|user|>\n{content}<|end|>\n<|assistant|>{choice}<|end|>\n" for choice in choices]
|
|
|
|
def extract_multi_choice_target_index(row, ds_name):
    """Return the zero-based index of the correct answer candidate of ``row``.

    Args:
        row: Mapping holding the fields of one dataset example.
        ds_name: Short dataset key (``"boolq"``, ``"hswag"``, ``"arc-challenge"``,
            ``"arc-easy"``, ``"oqa"`` or ``"wg"``).

    Returns:
        The index of the gold candidate, or ``None`` when ``ds_name`` is not a
        known key.
    """
    if ds_name in ("arc-challenge", "arc-easy", "oqa"):
        # The gold answer is given as a label; find its position among the labels.
        candidate_labels = row["choices"]["label"]
        return candidate_labels.index(row["answerKey"])

    if ds_name == "boolq":
        # Index 0 only when the answer is exactly `True`; anything else is index 1.
        return int(row["answer"] is not True)

    if ds_name == "hswag":
        return int(row["label"])

    if ds_name == "wg":
        # Convert the 1-based answer to a 0-based index.
        return int(row["answer"]) - 1

    return None
|
|
|
|
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
|
|
|
|
def compute_loglike_loss(logits, labels, reduction="none"):
    """Compute the next-token cross-entropy of ``labels`` under ``logits``.

    Logits at position ``t`` are scored against the label at position
    ``t + 1``. Labels equal to the loss function's ``ignore_index`` (-100 by
    default) contribute nothing.

    Args:
        logits: Tensor indexed as ``(batch, seq_len, vocab)``.
        labels: Tensor indexed as ``(batch, seq_len)``; a trailing singleton
            dimension (same rank as ``logits``) is squeezed away.
        reduction: Passed to ``torch.nn.CrossEntropyLoss``. With ``"none"``
            the token losses are averaged per row over the non-ignored
            tokens, giving one value per batch row; any other value returns
            whatever ``CrossEntropyLoss`` produces for it.

    Returns:
        A float32 tensor: shape ``(batch,)`` for ``reduction="none"``,
        otherwise the reduced loss.
    """
    bs = logits.size(0)
    vocab_size = logits.size(-1)

    # Only drop a trailing singleton axis when labels carry one extra rank;
    # an unconditional squeeze would collapse 2-D labels whose seq_len is 1.
    if labels.dim() == logits.dim():
        labels = labels.squeeze(-1)

    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    loss_fct = torch.nn.CrossEntropyLoss(reduction=reduction)
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1).to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)

    if reduction == "none":
        loss = loss.view((bs, -1))
        # Count scored tokens from the label mask rather than from `loss != 0`:
        # a confidently predicted token can have a loss of exactly zero and
        # must still count towards the average.
        scored = (shift_labels != loss_fct.ignore_index).view((bs, -1))
        # Rows with no scored token keep a denominator of 1 (their sum is 0).
        num_scored = scored.sum(dim=-1).clamp(min=1)
        loss = loss.sum(dim=-1) / num_scored

    return loss.float()
|
|
|
|
def evaluate_on_multi_choice_batched(
    eval_dataset, model, tokenizer, ds_name, labels, predictions, args, batch_size=32, max_length=512, device="cuda"
):
    """Score every answer candidate of every example and print the accuracy.

    For each batch of rows, all candidate texts are tokenized together, the
    prompt and padding positions are masked out of the labels, and the
    candidate with the lowest loss from ``compute_loglike_loss`` is taken as
    the prediction.

    Args:
        eval_dataset: Indexable dataset supporting ``len()`` and integer indexing.
        model: Causal LM called with ``input_ids`` / ``attention_mask`` and
            exposing ``.logits`` on its output.
        tokenizer: Tokenizer matching ``model``.
        ds_name: Short dataset key forwarded to the row helpers.
        labels: List that receives the gold index of each row (appended in place).
        predictions: List that receives the predicted index of each row
            (appended in place).
        args: Object with ``ds_name`` and ``strategy`` attributes; only used
            in the final printed message.
        batch_size: Number of dataset rows per forward pass (each row expands
            to one sequence per candidate).
        max_length: Truncation length passed to the tokenizer.
        device: Device the tokenized tensors are moved to.

    Returns:
        None. Results are delivered through ``labels`` / ``predictions`` and
        the printed accuracy.
    """
    model.eval()

    num_rows = len(eval_dataset)
    num_batches = (num_rows + batch_size - 1) // batch_size

    for batch_start in tqdm(range(0, num_rows, batch_size), total=num_batches):
        batch_end = min(batch_start + batch_size, num_rows)

        flat_texts = []  # every candidate text of the batch, row after row
        option_counts = []  # number of candidates contributed by each row
        prompt_lens = []  # per candidate: how many leading tokens to mask

        for row_idx in range(batch_start, batch_end):
            row = eval_dataset[row_idx]

            row_options = create_multi_choice_options(row, ds_name)
            option_counts.append(len(row_options))

            # Length of the shared prompt prefix, minus one token.
            # NOTE(review): the `- 1` leaves the last prompt token unmasked;
            # presumably deliberate, confirm against the reference results.
            question = extract_input_content(ds_name, row)
            prompt = f"<|user|>\n{question}<|end|>\n<|assistant|>"
            prompt_len = len(tokenizer.encode(prompt)) - 1

            flat_texts += row_options
            prompt_lens += [prompt_len] * len(row_options)

            labels.append(extract_multi_choice_target_index(row, ds_name))

        encoded = tokenizer(
            flat_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        # Labels are the inputs with padding and prompt positions set to -100.
        target_ids = input_ids.clone()
        target_ids[attention_mask == 0] = -100
        for option_idx, prompt_len in enumerate(prompt_lens):
            target_ids[option_idx, :prompt_len] = -100

        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            option_losses = compute_loglike_loss(logits, target_ids, reduction="none").detach().cpu()

        # Regroup the flat loss vector by row and pick the lowest-loss candidate.
        for row_losses in torch.split(option_losses, option_counts):
            predictions.append(torch.argmin(row_losses).item())

    accuracy = accuracy_score(labels, predictions)
    print(f"Accuracy for dataset {args.ds_name} and strategy {args.strategy} is: {accuracy}")
|
|
|
|
if __name__ == "__main__":
    args = parse_args()
    print(f"Selected strategy: {args.strategy}")
    print(f"Dataset name: {args.ds_name}")

    # Tokenizer for the base model (right padding, capped at MODEL_MAX_LEN).
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        use_fast=True,
        padding_side="right",
        model_max_length=MODEL_MAX_LEN,
    )

    # 4-bit NF4 quantization with bfloat16 compute, no double quantization.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
    )

    # Quantized base model shared by all three strategies.
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=bnb_config,
    )

    test_dataset = read_test_dataset(args.ds_name)
    print(f"{args.ds_name} is loaded with size: {len(test_dataset)}.")

    labels, predictions = [], []

    if args.strategy == "base":
        # Evaluate the quantized base model directly.
        eval_model = base_model
        eval_batch_size = 64
    else:
        # "arrow" and "gks" share the routing settings; "gks" additionally
        # enables GKS and supplies the general adapters.
        arrow_kwargs = {"top_k": 3, "router_temperature": 1.0}
        general_adapter_paths = []
        if args.strategy == "gks":
            arrow_kwargs["use_gks"] = True
            general_adapter_paths = [
                "TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langen/checkpoint-17",
                "TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langfr/checkpoint-35",
                "TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langger/checkpoint-17",
            ]
        arrow_config = ArrowConfig(**arrow_kwargs)

        # The ten task-specific experts, ts_expert_0 .. ts_expert_9.
        task_specific_adapter_paths = [
            f"TahaBa/phi3-mini-clustered-flan/ts_expert_{expert_idx}" for expert_idx in range(10)
        ]

        eval_model = create_arrow_model(
            base_model=base_model,
            task_specific_adapter_paths=task_specific_adapter_paths,
            general_adapter_paths=general_adapter_paths,
            arrow_config=arrow_config,
        )
        eval_batch_size = 32

    with torch.no_grad():
        evaluate_on_multi_choice_batched(
            test_dataset,
            eval_model,
            tokenizer,
            args.ds_name,
            labels,
            predictions,
            args,
            batch_size=eval_batch_size,
            max_length=512,
            device="cuda",
        )
|
|