-
Notifications
You must be signed in to change notification settings - Fork 0
Add adapter model for conditioning #24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
523d51b
f69100b
0de05a2
16cd998
198c8a5
5f8044a
f9fb4b9
849f449
6a7fcda
8398864
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| import argparse | ||
| from os import environ | ||
| from os.path import join | ||
|
|
||
| import torch | ||
| from omegaconf import OmegaConf | ||
| from torch.utils.data import DataLoader | ||
|
|
||
| from src.conditioning.adapter import Adapter | ||
| from src.data.convai2_dataset import ConvAI2Dataset | ||
| from src.diffusion.model import DiDi | ||
| from src.pipeline.training import train_model | ||
| from src.utils import filter_warnings, setup_logger, zero_rank_info | ||
|
|
||
|
|
||
| def configure_arg_parser(): | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("config_path", type=str, help="Path to YAML config file") | ||
| parser.add_argument("dataset_dir", type=str, help="Path to dataset directory") | ||
| parser.add_argument("model_path", type=str, help="Path to DiDi model") | ||
| parser.add_argument("--condition", type=str, default="other", help="Type of persona") | ||
| return parser | ||
|
|
||
|
|
||
| def main(config_path: str, dataset_dir: str, model_path: str, condition: str): | ||
| filter_warnings() | ||
| setup_logger() | ||
| environ["TOKENIZERS_PARALLELISM"] = "false" | ||
|
|
||
| torch.set_float32_matmul_precision("high") | ||
|
|
||
| config = OmegaConf.load(config_path) | ||
| zero_rank_info(f"Loaded config:\n{OmegaConf.to_yaml(config, resolve=False, sort_keys=True)}") | ||
|
|
||
| train_dataset = ConvAI2Dataset( | ||
| join(dataset_dir, f"train_{condition}_revised_no_cands.txt"), config.base_name, **config.dataset | ||
| ) | ||
| val_dataset = ConvAI2Dataset( | ||
| join(dataset_dir, f"valid_{condition}_revised_no_cands.txt"), config.base_name, **config.dataset | ||
| ) | ||
|
|
||
| train_dataloader = DataLoader( | ||
| train_dataset, | ||
| batch_size=config.batch_size, | ||
| collate_fn=train_dataset.collate_fn, | ||
| pin_memory=True, | ||
| num_workers=1, | ||
| ) | ||
|
|
||
| val_dataloader = DataLoader( | ||
| val_dataset, | ||
| batch_size=config.val_batch_size, | ||
| collate_fn=val_dataset.collate_fn, | ||
| pin_memory=True, | ||
| num_workers=1, | ||
| ) | ||
|
|
||
| didi = DiDi.load_from_checkpoint(model_path) | ||
| model = Adapter(didi) | ||
|
|
||
| train_model( | ||
| model, | ||
| train_dataloader, | ||
| val_dataloader, | ||
| config.trainer, | ||
| seed=config.seed, | ||
| save_interval=config.save_interval, | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| _args = configure_arg_parser().parse_args() | ||
| main(**vars(_args)) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| import torch | ||
| from lightning import LightningModule | ||
| from torch import nn | ||
|
|
||
| from src.diffusion.model import DiDi | ||
| from src.diffusion.utils import get_diffusion_variables, get_x0 | ||
| from src.pipeline.utils import calculate_train_step, freeze_params, get_cached_content, get_optimizers | ||
|
|
||
|
|
||
| class AdapterBlock(nn.Module): | ||
| def __init__(self, input_dim: int, num_heads: int): | ||
| super().__init__() | ||
| self.attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads) | ||
| self.query = nn.Linear(input_dim, input_dim) | ||
| self.key = nn.Linear(input_dim, input_dim) | ||
| self.value = nn.Linear(input_dim, input_dim) | ||
|
|
||
| def forward(self, hidden_states, encoder_hidden_states, encoder_attention_mask): | ||
| query = self.query(hidden_states) | ||
| key = self.key(encoder_hidden_states) | ||
| value = self.value(encoder_hidden_states) | ||
| return self.attention(key, query, value, need_weights=False, key_padding_mask=encoder_attention_mask > 0) | ||
|
|
||
|
|
||
| class Adapter(LightningModule): | ||
| def __init__(self, didi: DiDi, lr: float = 0.001, warmup_steps: int = 1, min_lr: float = None): | ||
| super().__init__() | ||
| self.didi = didi | ||
| freeze_params(self.didi) | ||
|
|
||
| self.decoder_layers = [] | ||
| adapter_layers = [] | ||
| for layer in didi.decoder.encoder.layer: | ||
| self.decoder_layers.append(layer) | ||
| adapter_layers.append(AdapterBlock(layer.output.dense.out_features, 1)) | ||
|
|
||
| self.adapter_layers = nn.ModuleList(adapter_layers) | ||
|
|
||
| self.lr, self.warmup, self.min_lr = lr, warmup_steps, min_lr | ||
|
|
||
| def configure_optimizers(self): | ||
| return get_optimizers(self) | ||
|
|
||
| def forward( | ||
| self, | ||
| encoder_input_ids: torch.Tensor = None, | ||
| encoder_attention_mask: torch.Tensor = None, | ||
| decoder_inputs_embeds: torch.Tensor = None, | ||
| condition_input_ids: torch.Tensor = None, | ||
| condition_attention_mask: torch.Tensor = None, | ||
| time_ids: torch.Tensor = None, | ||
| context: torch.Tensor = None, | ||
| condition: torch.Tensor = None, | ||
| ): | ||
| if encoder_input_ids is None and context is None: | ||
| raise ValueError("Either `encoder_input_ids` or `context` must be provided.") | ||
|
|
||
| if condition_input_ids is None and condition is None: | ||
| raise ValueError("Either `condition_input_ids` or `condition` must be provided.") | ||
|
|
||
| context = context or get_cached_content(self.didi, encoder_input_ids, encoder_attention_mask) | ||
| condition = condition or get_cached_content(self.didi, condition_input_ids, condition_attention_mask) | ||
|
|
||
| time_embeds = self.didi.time_embeds(time_ids) | ||
| hidden_states = decoder_inputs_embeds + time_embeds | ||
|
|
||
| for decoder_layer, adapter_layer in zip(self.decoder_layers, self.adapter_layers): | ||
| output = decoder_layer( | ||
| hidden_states=hidden_states, | ||
| encoder_hidden_states=context, | ||
| encoder_attention_mask=encoder_attention_mask, | ||
| )[0] | ||
| hidden_states = adapter_layer( | ||
| hidden_states=output, | ||
| encoder_hidden_states=condition, | ||
| encoder_attention_mask=condition_attention_mask, | ||
| )[0] | ||
|
|
||
| return hidden_states, context, condition | ||
|
|
||
| def training_step(self, batch: list, batch_idx: int): | ||
| raw_context, target, condition = batch | ||
| emb = self.didi.emb(target.input_ids) | ||
| x_0 = get_x0(emb, self.didi.std_0) | ||
| noise = torch.randn_like(x_0) | ||
|
|
||
| # x: [batch size; seq len; emb dim], t: [batch size] | ||
| x_t, t = get_diffusion_variables(self.didi.diffusion_steps, x_0, self.didi.sigmas, noise) | ||
|
|
||
| x_0_hat, *_ = self( | ||
| encoder_input_ids=raw_context.input_ids, | ||
| encoder_attention_mask=raw_context.attention_mask, | ||
| decoder_inputs_embeds=x_t, | ||
| time_ids=t, | ||
| condition_input_ids=condition.input_ids, | ||
| condition_attention_mask=condition.attention_mask, | ||
| ) # [batch size; seq len; emb dim] | ||
|
|
||
| loss, metrics = calculate_train_step(self.didi, emb, x_0, x_0_hat, target, t) | ||
| self.log_dict(metrics, sync_dist=True, on_step=True, on_epoch=False) | ||
| return loss | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,12 @@ | ||
| from dataclasses import dataclass | ||
| from enum import Enum | ||
| from typing import Optional | ||
|
|
||
| from loguru import logger | ||
| from torch.utils.data import Dataset | ||
| from tqdm.auto import tqdm | ||
| from transformers import AutoTokenizer | ||
| from src.data.utils import Preprocessor | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -15,25 +17,48 @@ class ConvAI2Dialog: | |
| partner_persona: Optional[list[str]] = None | ||
|
|
||
|
|
||
| class Conditions(Enum): | ||
| NONE = 0 | ||
| YOUR = 1 | ||
| PARTNERS = 2 | ||
|
|
||
|
|
||
| def get_condition(path: str): | ||
| if "none" in path: | ||
| return Conditions.NONE | ||
| elif "self" in path: | ||
| return Conditions.YOUR | ||
| else: | ||
| return Conditions.PARTNERS | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic is too complicated. Since we are only collecting personas in the collect function, let's pass the required arguments directly to it. def collate_fn(..., return_my_persona, return_partner_persona)And also let's return dict, e.g., {
"context": self.context_tokenizer(str_contexts, max_length=self.max_context_len, **self.tokenizer_kwargs)
"candidates": self.candidate_tokenizer(str_candidates, max_length=self.max_target_len, **self.tokenizer_kwargs)
} |
||
|
|
||
|
|
||
| class ConvAI2Dataset(Dataset): | ||
| _YOUR_PERSONA_PREFIX = "your persona: " | ||
| _PARTNER_PERSONA_PREFIX = "partner's persona: " | ||
|
|
||
| def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, have_candidates=True): | ||
| def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, max_condition_len=None): | ||
| self.dataset = [] | ||
| self.num_dialogs = 0 | ||
|
|
||
| self.context_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, truncation_side="left") | ||
| self.candidate_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) | ||
|
|
||
| bos = self.context_tokenizer.bos_token | ||
| eos = self.context_tokenizer.eos_token | ||
| self.tokenizer_kwargs = { | ||
| "padding": "max_length", | ||
| "truncation": True, | ||
| "return_tensors": "pt", | ||
| "add_special_tokens": False, | ||
| } | ||
| preprocessor = Preprocessor(tokenizer_name) | ||
| bos = preprocessor.bos | ||
| eos = preprocessor.eos | ||
|
|
||
| self.max_context_len = max_context_len | ||
| self.max_target_len = max_target_len or max_context_len | ||
| self.max_condition_len = max_condition_len or max_context_len | ||
|
|
||
| self.have_candidates = have_candidates | ||
| self.have_candidates = not "no_cands" in path | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this should be in the path? We use the same dataset for working with and without candidates |
||
| self.vocab_size = self.context_tokenizer.vocab_size | ||
| self.condition = get_condition(path) | ||
|
|
||
| logger.info(f"Loading dataset from '{path}'") | ||
| with open(path, "r") as f: | ||
|
|
@@ -53,7 +78,7 @@ def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, h | |
| partner_persona.append(line[len(self._PARTNER_PERSONA_PREFIX) :]) | ||
| continue | ||
|
|
||
| if have_candidates: | ||
| if self.have_candidates: | ||
| utterance1, utterance2, _, candidates_str = line.split("\t") | ||
| else: | ||
| utterance1, utterance2, *_ = line.split("\t") | ||
|
|
@@ -79,26 +104,29 @@ def collate_fn(self, samples: list[ConvAI2Dialog], return_all_candidates: bool = | |
| return_all_candidates = self.have_candidates & return_all_candidates | ||
| str_contexts = [" ".join(sample.context) for sample in samples] | ||
| # [batch size, context seq len] | ||
| b_contexts = self.context_tokenizer( | ||
| str_contexts, | ||
| max_length=self.max_context_len, | ||
| padding=True, | ||
| truncation=True, | ||
| return_tensors="pt", | ||
| add_special_tokens=False, | ||
| ).input_ids | ||
| b_contexts = self.context_tokenizer(str_contexts, max_length=self.max_context_len, **self.tokenizer_kwargs) | ||
|
|
||
| str_conditions = [] | ||
| if self.condition is Conditions.YOUR: | ||
| str_conditions = [" ".join(sample.my_persona) for sample in samples] # type: ignore | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are you ignoring the mypy here? What error does it raise?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It raises |
||
| elif self.condition is Conditions.PARTNERS: | ||
| str_conditions = [" ".join(sample.partner_persona) for sample in samples] # type: ignore | ||
|
|
||
| if return_all_candidates: | ||
| str_candidates = [it for sample in samples for it in sample.candidates] | ||
| else: | ||
| str_candidates = [sample.candidates[0] for sample in samples] | ||
|
|
||
| # Tokenizer truncates on the left, but for candidates we want to truncate on the right | ||
| b_candidates = self.candidate_tokenizer( | ||
| str_candidates, padding="max_length", return_tensors="pt", add_special_tokens=False | ||
| ).input_ids | ||
| b_candidates = b_candidates[:, : self.max_target_len] | ||
| # [batch size, # candidates, candidates seq len] | ||
| b_candidates = b_candidates.view(len(samples), -1, b_candidates.size(1)) | ||
|
|
||
| return b_contexts, b_candidates.squeeze(1) | ||
| b_candidates = self.candidate_tokenizer(str_candidates, max_length=self.max_target_len, **self.tokenizer_kwargs) | ||
| # b_candidates = b_candidates[:, : self.max_target_len] | ||
| # # [batch size, # candidates, candidates seq len] | ||
| # b_candidates = b_candidates.view(len(samples), -1, b_candidates.size(1)) | ||
|
|
||
| if self.condition is Conditions.NONE: | ||
| return b_contexts, b_candidates # .squeeze(1) | ||
| else: | ||
| b_conditions = self.candidate_tokenizer( | ||
| str_conditions, max_length=self.max_condition_len, **self.tokenizer_kwargs | ||
| ) | ||
| return b_contexts, b_candidates, b_conditions # .squeeze(1) | ||
Uh oh!
There was an error while loading. Please reload this page.