Roman Solomatin
committed on
update after review
Browse files

listconranker.py (+101 -4)
CHANGED
@@ -30,6 +30,9 @@ from transformers import (
 import os
 from transformers.modeling_outputs import SequenceClassifierOutput
 from typing import Union, List, Optional
+from collections import defaultdict
+import numpy as np
+import math
 
 
 class ListConRankerConfig(BertConfig):
@@ -295,14 +298,15 @@ class ListConRankerModel(PreTrainedModel):
             if sep_idxs.numel() == 0:
                 raise ValueError(f"No SEP in sequence {idx}")
             first_sep = sep_idxs[0].item()
+            second_sep = sep_idxs[1].item()
 
             # Extract query and passage
             q_seq = seq[: first_sep + 1]
             q_mask = mask[: first_sep + 1]
             q_tt = torch.zeros_like(q_seq)
 
-            p_seq = seq[first_sep:]
-            p_mask = mask[first_sep:]
+            p_seq = seq[first_sep : second_sep + 1]
+            p_mask = mask[first_sep : second_sep + 1]
             p_seq = p_seq.clone()
             p_seq[0] = self.config.cls_token_id
             p_tt = torch.zeros_like(p_seq)
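The change above bounds the passage span by the second [SEP] instead of taking everything after the first one, so padding appended by batch collation no longer leaks into the passage tokens. A minimal sketch of the split, assuming the standard BERT special-token ids (101/102 are illustrative, not taken from the commit):

import torch

CLS, SEP = 101, 102  # assumed BERT special-token ids
seq = torch.tensor([CLS, 7, 8, SEP, 9, 10, SEP, 0, 0])  # [CLS] q [SEP] p [SEP] <pad>

sep_idxs = (seq == SEP).nonzero(as_tuple=True)[0]
first_sep, second_sep = sep_idxs[0].item(), sep_idxs[1].item()

q_seq = seq[: first_sep + 1]                     # [CLS] q [SEP]
p_seq = seq[first_sep : second_sep + 1].clone()  # [SEP] p [SEP], padding excluded
p_seq[0] = CLS                                   # reuse the leading slot as [CLS]

print(q_seq.tolist())  # [101, 7, 8, 102]
print(p_seq.tolist())  # [101, 9, 10, 102]

With the old seq[first_sep:] slice, p_seq would have ended [..., 102, 0, 0], carrying the pad tokens into the passage encoding.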
@@ -315,6 +319,16 @@ class ListConRankerModel(PreTrainedModel):
                 ].tolist()
             )
 
+            # truncation
+            q_seq = q_seq[: self.config.max_position_embeddings]
+            q_seq[-1] = self.config.sep_token_id
+            p_seq = p_seq[: self.config.max_position_embeddings]
+            p_seq[-1] = self.config.sep_token_id
+            q_mask = q_mask[: self.config.max_position_embeddings]
+            p_mask = p_mask[: self.config.max_position_embeddings]
+            q_tt = q_tt[: self.config.max_position_embeddings]
+            p_tt = p_tt[: self.config.max_position_embeddings]
+
             if key not in grouped:
                 grouped[key] = {
                     "query": (q_seq, q_mask, q_tt),
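The new truncation block clips every split tensor to the encoder's position limit and then rewrites the last token as [SEP], so a clipped sequence still ends with a separator. A minimal sketch of the clip-then-reseal pattern (toy limit and ids, assumed rather than read from the config):

import torch

MAX_LEN, SEP = 4, 102  # toy position limit and assumed SEP id
q_seq = torch.tensor([101, 7, 8, 9, 10, 102])  # longer than MAX_LEN

q_seq = q_seq[:MAX_LEN]  # clip to the position limit
q_seq[-1] = SEP          # clipped sequences must still end in [SEP]
print(q_seq.tolist())    # [101, 7, 8, 102]

When the sequence is already within the limit, the last token is the closing [SEP] itself, so the overwrite is a no-op.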
@@ -396,7 +410,7 @@ class ListConRankerModel(PreTrainedModel):
     ):
         model = super().from_pretrained(model_name_or_path, config=config, **kwargs)
         model.hf_model = BertModel.from_pretrained(
-            model_name_or_path, config=model.config.bert_config
+            model_name_or_path, config=model.config.bert_config, **kwargs
         )
 
         linear_path = os.path.join(model_name_or_path, "linear_in_embedding.pt")
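Forwarding **kwargs here means loading options passed to the outer from_pretrained now reach the inner BertModel as well. A hypothetical usage sketch (torch_dtype is a standard from_pretrained kwarg; the import path is an assumption):

import torch
from listconranker import ListConRankerModel  # assumed local import path

# Both the wrapper and the inner BertModel now load in half precision.
model = ListConRankerModel.from_pretrained(
    "ByteDance/ListConRanker",
    torch_dtype=torch.float16,
)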
@@ -439,11 +453,94 @@ class ListConRankerModel(PreTrainedModel):
             inputs = tokenizer(
                 batch_pairs,
                 padding=True,
-                truncation=True,
+                truncation=False,
                 return_tensors="pt",
             )
+
+            for k, v in inputs.items():
+                inputs[k] = v.to(self.device)
+
             logits = self(**inputs)[0]
             total_logits[batch * batch_size : (batch + 1) * batch_size] = (
                 logits.squeeze(1)
             )
         return total_logits
+
+    def multi_passage_in_iterative_inference(
+        self,
+        sentences: List[str],
+        stop_num: int = 20,
+        decrement_rate: float = 0.2,
+        min_filter_num: int = 10,
+        tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(
+            "ByteDance/ListConRanker"
+        ),
+    ):
+        """
+        Process multiple passages for one query with iterative inference.
+        :param sentences: List containing the query followed by its passages.
+        :return: List of final scores, one per passage.
+        """
+        if stop_num < 1:
+            raise ValueError("stop_num must be greater than 0")
+        if decrement_rate <= 0 or decrement_rate >= 1:
+            raise ValueError("decrement_rate must be in (0, 1)")
+        if min_filter_num < 1:
+            raise ValueError("min_filter_num must be greater than 0")
+
+        query = sentences[0]
+        passage = sentences[1:]
+
+        filter_times = 0
+        passage2score = defaultdict(list)
+        while len(passage) > stop_num:
+            batch = [[query] + passage]
+            pred_scores = self.multi_passage(
+                batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer
+            ).tolist()
+            pred_scores_argsort = np.argsort(
+                pred_scores
+            ).tolist()  # Sort in increasing order
+
+            passage_len = len(passage)
+            to_filter_num = math.ceil(passage_len * decrement_rate)
+            if to_filter_num < min_filter_num:
+                to_filter_num = min_filter_num
+
+            have_filter_num = 0
+            while have_filter_num < to_filter_num:
+                idx = pred_scores_argsort[have_filter_num]
+                passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
+                have_filter_num += 1
+            while (
+                pred_scores[pred_scores_argsort[have_filter_num - 1]]
+                == pred_scores[pred_scores_argsort[have_filter_num]]
+            ):
+                idx = pred_scores_argsort[have_filter_num]
+                passage2score[passage[idx]].append(pred_scores[idx] + filter_times)
+                have_filter_num += 1
+            next_passage = []
+            next_passage_idx = have_filter_num
+            while next_passage_idx < len(passage):
+                idx = pred_scores_argsort[next_passage_idx]
+                next_passage.append(passage[idx])
+                next_passage_idx += 1
+            passage = next_passage
+            filter_times += 1
+
+        batch = [[query] + passage]
+        pred_scores = self.multi_passage(
+            batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer
+        ).tolist()
+
+        cnt = 0
+        while cnt < len(passage):
+            passage2score[passage[cnt]].append(pred_scores[cnt] + filter_times)
+            cnt += 1
+
+        passage = sentences[1:]
+        final_score = []
+        for i in range(len(passage)):
+            p = passage[i]
+            final_score.append(passage2score[p][0])
+        return final_score
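The new multi_passage_in_iterative_inference helper filters passages over several rounds: each round scores all surviving passages with multi_passage, drops at least min_filter_num of them (or a decrement_rate fraction of the pool, whichever is larger, plus any passages tied at the cut), and records each dropped passage's score plus the number of completed rounds, so passages eliminated later carry larger offsets; the loop stops once stop_num or fewer passages remain, and the survivors are scored in a final pass. A hypothetical usage sketch (query and passages are invented; model is assumed to be a loaded ListConRankerModel; the checkpoint name is the one used in the commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ByteDance/ListConRanker")
query = "What does a listwise reranker do?"
passages = [f"candidate passage {i}" for i in range(40)]  # toy passages

scores = model.multi_passage_in_iterative_inference(
    [query] + passages,
    tokenizer=tokenizer,
)
ranked = sorted(zip(passages, scores), key=lambda s: s[1], reverse=True)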