Text Generation
Transformers
PyTorch
English
prot2text
feature-extraction
Causal Language Modeling
GPT2
ESM2
Proteins
GNN
custom_code
Instructions to use habdine/Prot2Text-Medium-v1-0 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use habdine/Prot2Text-Medium-v1-0 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="habdine/Prot2Text-Medium-v1-0", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("habdine/Prot2Text-Medium-v1-0", trust_remote_code=True, dtype="auto")

- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use habdine/Prot2Text-Medium-v1-0 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "habdine/Prot2Text-Medium-v1-0"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "habdine/Prot2Text-Medium-v1-0",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'

Use Docker
docker model run hf.co/habdine/Prot2Text-Medium-v1-0
- SGLang
How to use habdine/Prot2Text-Medium-v1-0 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "habdine/Prot2Text-Medium-v1-0" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "habdine/Prot2Text-Medium-v1-0",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'

Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "habdine/Prot2Text-Medium-v1-0" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "habdine/Prot2Text-Medium-v1-0",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'

- Docker Model Runner
How to use habdine/Prot2Text-Medium-v1-0 with Docker Model Runner:
docker model run hf.co/habdine/Prot2Text-Medium-v1-0
| import torch.nn as nn | |
| from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP | |
| from typing import Optional, Tuple, Union, Any, Dict, List | |
| from transformers import Seq2SeqTrainer, GPT2LMHeadModel | |
| from torch.utils.data.distributed import DistributedSampler | |
| import torch | |
| from transformers.deepspeed import is_deepspeed_zero3_enabled | |
| from transformers.generation.logits_process import LogitsProcessorList | |
| from transformers.generation.stopping_criteria import StoppingCriteriaList | |
| from transformers.generation.utils import GreedySearchOutput, GreedySearchEncoderDecoderOutput, BeamSearchOutput, BeamSearchEncoderDecoderOutput | |
| from transformers.generation.beam_search import BeamScorer | |
| try: | |
| from torch_geometric.loader import DataLoader | |
| from torch_geometric.data import Dataset | |
| except ImportError: | |
| raise Exception('You need to install torch geometric and its dependecies to use this model please refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html') | |
class _GPT2LMHeadModel(GPT2LMHeadModel):
    '''
    GPT-2 LM-head decoder whose generation loop cross-attends to an encoder output.

    The generation hooks below are edited copies of HuggingFace's transformers code.
    Their main change is that ``model_kwargs["encoder_outputs"]["hidden_states"]`` is
    forwarded to every decoding step as ``encoder_hidden_states``.

    No ``__init__`` override: ``GPT2LMHeadModel.__init__`` is used as-is.
    '''

    @staticmethod
    def _stopping_criteria_met(stopping_criteria, input_ids, scores):
        '''
        Evaluate ``stopping_criteria`` once and reduce the result to a single bool.

        Depending on the transformers version, ``StoppingCriteriaList.__call__`` may return
        a plain bool or a per-sequence BoolTensor. A tensor result counts as "stop" only
        when every element is true. This matches the previous fallback of
        ``all(stopping_criteria(...))``, but does not call the criteria twice and does not
        hide unrelated exceptions behind a bare ``except``.
        '''
        is_done = stopping_criteria(input_ids, scores)
        if isinstance(is_done, torch.Tensor):
            return bool(is_done.all())
        return bool(is_done)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, encoder_outputs=None, **kwargs):
        '''
        Build the forward() kwargs for one decoding step.

        This function is an edited version of the prepare_inputs_for_generation function from HuggingFace's transformers
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py

        Args:
            input_ids: tokens generated so far.
            past_key_values: decoder cache. When truthy, only the last token (and last
                token_type/position id) is fed to the model.
            encoder_outputs: mapping holding the encoder output under ``'hidden_states'``.
                It is passed on as ``encoder_hidden_states`` and must not be None.
            **kwargs: may contain ``token_type_ids``, ``attention_mask``, ``position_ids``,
                ``use_cache`` and (for prot2text_version 1.1/1.2) ``encoder_attention_mask``.

        Raises:
            ValueError: if ``self.config.prot2text_version`` is not "1.0", "1.1" or "1.2".
        '''
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        version = self.config.prot2text_version
        if version == "1.1" or version == "1.2":
            encoder_attention_mask = kwargs.get("encoder_attention_mask", None)
        elif version == "1.0":
            # version 1.0 always decodes without an encoder attention mask
            encoder_attention_mask = None
        else:
            # previously this fell through and raised UnboundLocalError further down
            raise ValueError(
                f"Unsupported prot2text_version {version!r}; expected '1.0', '1.1' or '1.2'."
            )

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch (left-padded) generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            # same as upstream GPT-2: explicitly passed position_ids are not forwarded
            position_ids = None

        model_specific_kwargs = {
            "encoder_hidden_states": encoder_outputs['hidden_states'],
        }

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "encoder_attention_mask": encoder_attention_mask,
            **model_specific_kwargs
        }

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        '''
        Greedy decoding: append the argmax of the processed logits at every step.

        This function is an edited version of the greedy_search function from HuggingFace's transformers
        https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py

        Args:
            input_ids: starting token ids, one row per sequence.
            logits_processor: applied to each step's last-position logits (empty list if None).
            stopping_criteria: evaluated after each step (empty list if None).
            max_length: deprecated; converted into a stopping criterion with a warning.
            pad_token_id, eos_token_id: default to ``self.generation_config``. Once a sequence
                emits an EOS token, it is padded with ``pad_token_id``.
            output_attentions, output_hidden_states, output_scores, return_dict_in_generate:
                default to ``self.generation_config``.
            synced_gpus: keep calling forward until every rank has finished
                (all-reduce over ``torch.distributed``).
            streamer: optional; receives each new token batch via ``put`` and ``end`` at the end.
            **model_kwargs: forwarded to ``prepare_inputs_for_generation`` (must contain
                ``encoder_outputs``).

        Returns:
            ``input_ids`` with generated tokens appended, or a ``GreedySearch*Output`` when
            ``return_dict_in_generate`` is true.

        Raises:
            ValueError: if ``eos_token_id`` is set but ``pad_token_id`` is not.
        '''
        import warnings

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            # imported lazily: only needed on this deprecated path
            from transformers.generation.stopping_criteria import validate_stopping_criteria

            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        if synced_gpus:
            import torch.distributed as dist

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_tokens_scores,)
                if output_attentions:
                    # decoder-only outputs expose `.attentions`; the condition was previously inverted
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # argmax
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )
                # stop when each sentence is finished
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True

            # stop if we exceed the maximum length (or any other criterion fires for all sequences)
            if self._stopping_criteria_met(stopping_criteria, input_ids, scores):
                this_peer_finished = True

            if this_peer_finished and not synced_gpus:
                break

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            # was referenced without being imported
            from transformers.generation.utils import GreedySearchDecoderOnlyOutput

            return GreedySearchDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
        return input_ids

    def _greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        '''Private-name alias of ``greedy_search``; arguments are forwarded positionally.'''
        return self.greedy_search(
            input_ids,
            logits_processor,
            stopping_criteria,
            max_length,
            pad_token_id,
            eos_token_id,
            output_attentions,
            output_hidden_states,
            output_scores,
            return_dict_in_generate,
            synced_gpus,
            streamer,
            **model_kwargs,
        )

    def _beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        '''Private-name alias of ``beam_search``; arguments are forwarded positionally.'''
        return self.beam_search(
            input_ids,
            beam_scorer,
            logits_processor,
            stopping_criteria,
            max_length,
            pad_token_id,
            eos_token_id,
            output_attentions,
            output_hidden_states,
            output_scores,
            return_dict_in_generate,
            synced_gpus,
            **model_kwargs,
        )

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        '''
        Beam-search decoding driven by ``beam_scorer``.

        This function is an edited version of the beam_search function from HuggingFace's transformers
        https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py

        Args:
            input_ids: starting ids of shape (batch_size * num_beams, seq_len). The first
                dimension must equal ``beam_scorer.num_beams * len(beam_scorer._beam_hyps)``.
            beam_scorer: keeps the beam hypotheses; its ``process``/``finalize`` select tokens.
            logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id,
            output_*, return_dict_in_generate, synced_gpus: as in ``greedy_search``.
            **model_kwargs: forwarded to ``prepare_inputs_for_generation`` (must contain
                ``encoder_outputs``). ``past_key_values`` is reordered to follow the selected beams.

        Returns:
            The best sequences, or a ``BeamSearch*Output`` when ``return_dict_in_generate`` is true.

        Raises:
            ValueError: if the batch dimension of ``input_ids`` does not match the beam scorer.
        '''
        import warnings

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            # imported lazily: only needed on this deprecated path
            from transformers.generation.stopping_criteria import validate_stopping_criteria

            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        # unpacking also enforces a 2-D input_ids
        batch_beam_size, _ = input_ids.shape
        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
        # of the first beam are considered to avoid sampling the exact same tokens across all beams.
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        if synced_gpus:
            import torch.distributed as dist

        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            # add each beam's running score to every candidate continuation
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                next_token_scores_processed
            )

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores_processed,)
                if output_attentions:
                    # decoder-only outputs expose `.attentions`; the condition was previously inverted
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            # split the flat (beam, token) index back into its two parts
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                # keep the cache aligned with the surviving beams
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            if beam_scorer.is_done or self._stopping_criteria_met(stopping_criteria, input_ids, scores):
                if not synced_gpus:
                    break
                this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            # was referenced without being imported
            from transformers.generation.utils import BeamSearchDecoderOnlyOutput

            return BeamSearchDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
        return sequence_outputs["sequences"]
class CABlock(nn.Module):
    '''
    Pre-norm cross-attention block followed by a pre-norm MLP, both with residual adds.

    Adapted from the GPT-2 decoder block in HuggingFace's transformers
    (https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py),
    with the self-attention sub-layer removed: only cross-attention over
    ``encoder_hidden_states`` and the feed-forward MLP remain.
    '''

    def __init__(self, config, layer_idx=None):
        super().__init__()
        width = config.hidden_size
        # GPT-2 convention: the MLP inner width defaults to 4x the hidden size
        mlp_width = 4 * width if config.n_inner is None else config.n_inner
        eps = config.layer_norm_epsilon
        # Submodules are created in the same order as upstream so that parameter
        # initialisation order and state_dict key order are unchanged.
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
        self.ln_cross_attn = nn.LayerNorm(width, eps=eps)
        self.mlp = GPT2MLP(mlp_width, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        '''
        Apply cross-attention and the MLP to ``hidden_states``.

        ``layer_past`` and ``use_cache`` are accepted for signature compatibility with a
        GPT-2 block but are not used. Only the first element of the cross-attention
        output is kept, so attention weights are not returned.

        Returns:
            A 1-tuple containing the updated hidden states.
        '''
        normed = self.ln_cross_attn(hidden_states)
        attended = self.crossattention(
            normed,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )[0]
        after_attn = hidden_states + attended

        mlp_out = self.mlp(self.ln_2(after_attn))
        return (after_attn + mlp_out,)
class Prot2TextTrainer(Seq2SeqTrainer):
    '''
    Seq2SeqTrainer adapted to Prot2Text's torch_geometric batches.

    This is an edited version of the Seq2SeqTrainer from HuggingFace's transformers:
      - dataloaders are torch_geometric ``DataLoader``s, sharded with ``DistributedSampler``
        when ``world_size > 1``;
      - batches are converted to plain dicts, and ``edge_type`` is turned into class indices;
      - generation feeds the graph/sequence encoder inputs to ``model.generate``.
    '''

    def _build_dataloader(self, dataset, batch_size) -> DataLoader:
        '''Create a torch_geometric DataLoader for ``dataset``, sharded across processes when distributed.'''
        if self.args.world_size > 1:
            sampler = DistributedSampler(dataset, num_replicas=self.args.world_size, rank=self.args.process_index)
        else:
            # NOTE(review): with no sampler and `shuffle` left at its default, a single-process
            # loader (training included) iterates in dataset order — confirm this is intended.
            sampler = None
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=None,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            sampler=sampler,
        )

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        '''
        Return the evaluation DataLoader.

        Args:
            eval_dataset: dataset to evaluate on. Falls back to ``self.eval_dataset`` when None.
                A string is treated as a key into ``self.eval_dataset``, since newer Trainer
                versions pass dataset names when ``eval_dataset`` is a dict.
                (Previously this argument was ignored entirely.)
        '''
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        elif isinstance(eval_dataset, str):
            eval_dataset = self.eval_dataset[eval_dataset]
        return self._build_dataloader(eval_dataset, self.args.eval_batch_size)

    def get_train_dataloader(self) -> DataLoader:
        '''Return the training DataLoader over ``self.train_dataset``.'''
        return self._build_dataloader(self.train_dataset, self.args.per_device_train_batch_size)

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.

        The batch object is converted with ``to_dict()``. The per-graph ``edge_type`` entries
        are concatenated and reduced with ``argmax`` over dim 1. Everything exposing ``.to``
        is then moved to ``self.args.device``.

        Raises:
            ValueError: if the batch is empty.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            expected = ','.join(self._signature_columns or [])
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {expected}."
            )
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past
        inputs = inputs.to_dict()
        # as_tensor avoids the copy (and the copy-construct warning) torch.tensor() makes for tensor inputs
        inputs['edge_type'] = torch.cat([torch.as_tensor(edge_type) for edge_type in inputs['edge_type']], dim=0)
        # presumably edge_type is one-hot per edge; argmax over dim 1 gives its index — TODO confirm
        inputs['edge_type'] = torch.argmax(inputs['edge_type'], dim=1)
        inputs = {k: v.to(device=self.args.device, non_blocking=True) if hasattr(v, 'to') else v for k, v in inputs.items()}
        return inputs

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).

        Raises:
            ValueError: when generating and the batch has no `decoder_input_ids`.
        """
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # XXX: adapt synced_gpus for fairscale as well
        gen_kwargs = self._gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.model.config.max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
        )
        default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
        gen_kwargs["synced_gpus"] = (
            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
        )

        if "attention_mask" in inputs:
            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
        if "global_attention_mask" in inputs:
            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)

        # Encoder inputs are passed as kwargs; there is no positional generation input.
        generation_inputs = None
        gen_kwargs['x'] = inputs.get('x', None)
        gen_kwargs['edge_index'] = inputs.get('edge_index', None)
        gen_kwargs['edge_type'] = inputs.get('edge_type', None)
        gen_kwargs['batch'] = inputs.get('batch', None)
        gen_kwargs['encoder_input_ids'] = inputs.get('encoder_input_ids', None)

        decoder_input_ids = inputs.get('decoder_input_ids', None)
        if decoder_input_ids is None:
            raise ValueError(
                "`decoder_input_ids` are required in the batch to seed generation in Prot2TextTrainer.prediction_step."
            )
        # Seed generation with only the first decoder token of each sequence.
        gen_kwargs['decoder_input_ids'] = decoder_input_ids[:, 0:1]
        gen_kwargs["decoder_attention_mask"] = torch.ones(gen_kwargs['decoder_input_ids'].shape[0], 1).to(self.args.device)

        generated_tokens = self.model.generate(
            generation_inputs,
            **gen_kwargs,
        )

        # in case the batch is shorter than max length, the output should be padded
        if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
        elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
            gen_kwargs["max_new_tokens"] + 1
        ):
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)

        with torch.no_grad():
            if has_labels:
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None)

        if has_labels:
            labels = inputs["labels"]
            if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
            elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
                gen_kwargs["max_new_tokens"] + 1
            ):
                labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
        else:
            labels = None

        return (loss, generated_tokens, labels)