"""Select the ``LlamaForCausalLM`` implementation for the installed transformers.

On import this module:

* parses ``transformers.__version__``;
* binds ``LlamaForCausalLM`` to the stock transformers class for the 5.x
  series and later, or to ``LlamaForCausalLMYarn4x`` from the sibling
  ``llama_yarn_patch_4x`` module for anything older;
* prints which implementation was chosen and emits a ``RuntimeWarning`` when
  the installed version differs from the version the patch targets.
"""

import warnings

import transformers
from packaging import version
from transformers import LlamaForCausalLM as HFLlamaForCausalLM

_TRANSFORMERS_VERSION = version.parse(transformers.__version__)
print(f"[llama-yarn] Detected transformers version: {_TRANSFORMERS_VERSION}")

# PEP 440 orders dev/pre-releases *before* the final release, so
# "5.0.0rc1 >= 5.0.0" is False.  Comparing against the lowest version of the
# 5.0.0 series keeps 5.x dev builds and release candidates off the 4.x patch.
_FIRST_UNPATCHED_VERSION = version.parse("5.0.0.dev0")

if _TRANSFORMERS_VERSION >= _FIRST_UNPATCHED_VERSION:
    # No patch is applied; the "patch version" is the installed version, so the
    # exact-match check below reports a match.
    _patch_version = _TRANSFORMERS_VERSION
    print(
        "[llama-yarn] Using default transformers implementation, "
        f"since transformers version {_patch_version} >= 5.0.0"
    )
    LlamaForCausalLM = HFLlamaForCausalLM
else:
    # The 4.x patch targets one specific transformers release.
    _patch_version = version.parse("4.46.3")
    print(f"[llama-yarn] Using transformers<5 patch (target version {_patch_version})")
    # Imported only on this branch so the patch module is never loaded when the
    # stock implementation is used.
    from .llama_yarn_patch_4x import LlamaForCausalLMYarn4x as LlamaForCausalLM

if _TRANSFORMERS_VERSION == _patch_version:
    print(
        "[llama-yarn] Patch version matches transformers exactly "
        f"({_TRANSFORMERS_VERSION})"
    )
else:
    # Any other pre-5 version still gets the patch, but only as best effort.
    warnings.warn(
        "[llama-yarn] Patch version mismatch:\n"
        f"  transformers installed: {_TRANSFORMERS_VERSION}\n"
        f"  patch built for:        {_patch_version}\n"
        "The model may still work but compatibility is not guaranteed.",
        RuntimeWarning,
    )