import torch
import torch.nn as nn
|
|
class OTitansTriArchRouter(nn.Module):
    """
    Phase 3: The Tri-Arch Router.

    Dynamically routes forward passes between the frozen base model,
    the Memory OTITANS gate, and potential Skill OTITANS gates.

    The base model produces its final hidden states; each active gate
    computes a residual correction that is added back scaled by its alpha,
    and the (frozen) base model's LM head re-projects the result to logits.
    """

    def __init__(self, base_model, memory_gate, skill_gate=None):
        """
        Args:
            base_model: Causal LM exposing ``hidden_states`` on its output
                (called with ``output_hidden_states=True, return_dict=True``)
                and an ``lm_head`` projection.  Assumed frozen — TODO confirm.
            memory_gate: Callable mapping hidden states -> residual states.
            skill_gate: Optional callable with the same contract; disabled
                by default via ``current_skill_alpha = 0.0``.
        """
        super().__init__()
        self.base_model = base_model
        self.memory_gate = memory_gate
        self.skill_gate = skill_gate

        # Routing coefficients: memory path fully on, skill path fully off
        # by default.  Adjusted per-pass via set_routing_alphas().
        self.current_memory_alpha = 1.0
        self.current_skill_alpha = 0.0

    def set_routing_alphas(self, memory_alpha: float, skill_alpha: float) -> None:
        """Dynamically adjust the routing gates before a forward pass."""
        self.current_memory_alpha = memory_alpha
        self.current_skill_alpha = skill_alpha

    def forward(self, input_ids, **kwargs):
        """
        Route a forward pass through the base model and any active gates.

        Args:
            input_ids: Token ids passed straight to the base model.
            **kwargs: Extra arguments forwarded to the base model
                (e.g. ``attention_mask``).

        Returns:
            Logits from the base model's ``lm_head`` applied to the gated
            final hidden states.
        """
        # We require hidden states and a dict-style output below.  Drop any
        # caller-supplied duplicates first — otherwise splatting **kwargs
        # alongside the explicit keywords raises
        # "TypeError: got multiple values for keyword argument".
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("return_dict", None)

        base_outputs = self.base_model(
            input_ids,
            output_hidden_states=True,
            return_dict=True,
            **kwargs,
        )

        # Last element is the final layer's hidden states.
        hidden_states = base_outputs.hidden_states[-1]

        # Memory gate: additive residual, scaled by its routing alpha.
        if self.current_memory_alpha > 0.0 and self.memory_gate is not None:
            memory_states = self.memory_gate(hidden_states)
            hidden_states = hidden_states + (memory_states * self.current_memory_alpha)

        # Skill gate: applied on top of the memory-adjusted stream.
        if self.current_skill_alpha > 0.0 and self.skill_gate is not None:
            skill_states = self.skill_gate(hidden_states)
            hidden_states = hidden_states + (skill_states * self.current_skill_alpha)

        # Re-project the gated hidden states through the base LM head.
        logits = self.base_model.lm_head(hidden_states)

        return logits
|
|