# Source: TildeOpen-30b-64k / modeling_llama.py
# Author: Martins Kronis
# Commit message: add YaRN patch for transformers 4x
# Commit: a66e9c6
from packaging import version
import transformers
from transformers import LlamaForCausalLM as HFLlamaForCausalLM
import warnings
# Select the ``LlamaForCausalLM`` implementation exported by this module based
# on the installed transformers version:
#   * major version >= 5 -> re-export the stock transformers class unchanged;
#   * otherwise          -> use ``LlamaForCausalLMYarn4x`` from the local
#                           ``llama_yarn_patch_4x`` module, which targets
#                           transformers 4.46.3.
# Progress is reported on stdout; a RuntimeWarning is raised when the installed
# version differs from the version the chosen implementation targets.
_TRANSFORMERS_VERSION = version.parse(transformers.__version__)
print(f"[llama-yarn] Detected transformers version: {_TRANSFORMERS_VERSION}")

# Gate on the major component instead of comparing against Version("5.0.0").
# Under PEP 440 ordering, pre-release and dev builds of 5.0.0 (e.g. 5.0.0rc1,
# 5.0.0.dev0) sort *below* 5.0.0, so a plain ``>=`` comparison would route them
# to the 4.x patch. The ">= 5.0.0" wording in the message below is kept as-is
# and should be read at major-version granularity.
if _TRANSFORMERS_VERSION.major >= 5:
    # No patch is applied on this path, so the "patch version" is defined as
    # the installed version; the exact-match check below then succeeds.
    _patch_version = _TRANSFORMERS_VERSION
    print(
        f"[llama-yarn] Using default transformers implementation, "
        f"since transformers version {_patch_version} >= 5.0.0"
    )
    LlamaForCausalLM = HFLlamaForCausalLM
else:
    # The local patch was written against this exact transformers release.
    _patch_version = version.parse("4.46.3")
    print(f"[llama-yarn] Using transformers<5 patch (target version {_patch_version})")
    # Imported lazily so the patch module is only loaded when it is needed.
    from .llama_yarn_patch_4x import LlamaForCausalLMYarn4x as LlamaForCausalLM

if _TRANSFORMERS_VERSION == _patch_version:
    print(
        f"[llama-yarn] Patch version matches transformers exactly "
        f"({_TRANSFORMERS_VERSION})"
    )
else:
    # Only reachable on the <5 path when the installed release is not 4.46.3.
    warnings.warn(
        "[llama-yarn] Patch version mismatch:\n"
        f" transformers installed: {_TRANSFORMERS_VERSION}\n"
        f" patch built for: {_patch_version}\n"
        "The model may still work but compatibility is not guaranteed.",
        RuntimeWarning,
    )