| import os |
| import sys |
| import site |
|
|
# Best-effort cuDNN setup for onnxruntime-gpu on Linux: prepend the pip-installed
# nvidia cudnn library directory to LD_LIBRARY_PATH. The dynamic linker only
# reads LD_LIBRARY_PATH at process start, so the script re-execs itself once
# (guarded by the RESTARTED env var) to make the new path take effect.
try:
    cudnn_path = os.path.join(site.getsitepackages()[0], 'nvidia', 'cudnn', 'lib')
    if os.path.exists(cudnn_path):
        if 'LD_LIBRARY_PATH' in os.environ:
            # Prepend so the bundled cuDNN wins over any system-wide copy.
            os.environ['LD_LIBRARY_PATH'] = f"{cudnn_path}:{os.environ['LD_LIBRARY_PATH']}"
        else:
            os.environ['LD_LIBRARY_PATH'] = cudnn_path
        if "RESTARTED" not in os.environ:
            os.environ["RESTARTED"] = "1"
            # Replace the current process with a fresh interpreter run
            # (same script, same argv) that inherits the updated environment.
            os.execv(sys.executable, [sys.executable] + sys.argv)
except Exception:
    # Deliberate best-effort: if anything fails (no nvidia wheel installed,
    # unusual site-packages layout, ...) just continue and let onnxruntime
    # fall back to whatever provider it can load.
    pass
|
|
| import onnxruntime as ort |
|
|
| import tiktoken |
| import numpy as np |
| import time |
|
|
| |
# Path to the exported ONNX model (dynamic sequence-length axes).
MODEL_PATH = "Apex_1.5_Coder_DYNAMIC.onnx"
# Width of the logit slice taken from the model output. 50304 is larger than
# GPT-2's 50257 real tokens; ids >= 50257 are treated as padding (see run_chat).
VOCAB_SIZE = 50304
# GPT-2 BPE tokenizer, used both to encode prompts and to decode output.
enc = tiktoken.get_encoding("gpt2")
|
|
| |
# Session configuration: enable every graph-level optimization ORT offers.
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

print(f"🚀 Loading Dynamic ONNX Model: {MODEL_PATH}...")
# Provider priority: try CUDA first, fall back to CPU if CUDA is unavailable.
providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        # Grow the GPU memory arena in power-of-two steps.
        'arena_extend_strategy': 'kNextPowerOfTwo',
    }),
    'CPUExecutionProvider'
]
|
|
# Build the inference session once at import time; report which provider
# actually ended up active (CUDA or the CPU fallback).
try:
    session = ort.InferenceSession(MODEL_PATH, sess_options=options, providers=providers)
    print(f"✅ Active Provider: {session.get_providers()[0]}")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Exit with a non-zero status so shell callers can detect the failure;
    # a bare sys.exit() would exit 0 and look like success.
    sys.exit(1)
|
|
def get_param(prompt, default):
    """Prompt the user for a value, falling back to *default*.

    The reply is coerced to the type of *default* (e.g. int, float).
    An empty reply — or one that cannot be converted — returns *default*
    instead of raising, so a typo cannot crash the chat loop.
    """
    val = input(f"{prompt} (Default: {default}): ").strip()
    if not val:
        return default
    try:
        return type(default)(val)
    except (TypeError, ValueError):
        # e.g. "abc" entered for an int parameter — keep the default.
        return default
|
|
def apply_sampling(logits, temperature, top_k, repetition_penalty, history):
    """Sample one token id from raw `logits`.

    Applies, in order: a CTRL-style repetition penalty over the ids in
    `history`, temperature scaling, and top-k filtering; then draws from
    the resulting softmax distribution.

    Args:
        logits: 1-D array of raw scores. Not modified — a copy is taken.
        temperature: softmax temperature; clamped to >= 1e-6 to avoid
            division by zero.
        top_k: keep only the k largest logits; values <= 0 disable filtering.
        repetition_penalty: values > 1.0 discourage tokens already seen.
        history: iterable of previously seen/generated token ids.

    Returns:
        int: the sampled token index.
    """
    # Work on a float64 copy: the caller's array stays untouched, and the
    # softmax below is numerically safer than in float32.
    logits = np.asarray(logits, dtype=np.float64).copy()

    # CTRL-style penalty: divide positive logits, multiply negative ones,
    # so an already-seen token always loses probability mass.
    if repetition_penalty != 1.0 and len(history) > 0:
        for token in np.unique(history):
            if token < len(logits):
                if logits[token] > 0:
                    logits[token] /= repetition_penalty
                else:
                    logits[token] *= repetition_penalty

    # Temperature scaling (floored to avoid dividing by zero at T == 0).
    logits = logits / max(temperature, 1e-6)

    # Top-k: mask everything strictly below the k-th largest logit.
    # Skip entirely for top_k <= 0 — np.partition with a non-positive k
    # would silently mask the *smallest* logits instead.
    if top_k > 0:
        top_k = min(top_k, len(logits))
        kth_value = np.partition(logits, -top_k)[-top_k]
        logits[logits < kth_value] = -float('Inf')

    # Numerically stable softmax.
    exp_logits = np.exp(logits - np.max(logits))
    probs = exp_logits / np.sum(exp_logits)

    return int(np.random.choice(len(logits), p=probs))
|
|
def run_chat():
    """Interactive chat loop: read a prompt, stream a sampled completion.

    Uses the module-level `session` (ONNX Runtime), `enc` (GPT-2 tokenizer),
    `VOCAB_SIZE`, and the helpers `get_param` / `apply_sampling`.
    Loops until the user types exit, quit, or beenden.
    """
    print("\n" + "="*50)
    print(" APEX 1.5 DYNAMIC ONNX INTERACTIVE CHAT")
    print("="*50 + "\n")

    while True:
        user_input = input("You: ")
        if user_input.lower() in ["exit", "quit", "beenden"]:
            break

        # Per-turn sampling hyper-parameters (empty input keeps the default).
        temp = get_param(" Temperature", 0.55)
        tk = get_param(" Top-K", 40)
        rp = get_param(" Repetition Penalty", 1.2)
        max_tk = get_param(" Max New Tokens", 500)

        # Instruction/Response framing the model was presumably fine-tuned
        # on — NOTE(review): confirm against the training template.
        prompt = f"Instruction:\n{user_input}\n\nResponse:\n"
        input_ids = enc.encode(prompt)
        history = list(input_ids)  # all ids seen, for the repetition penalty

        print("\nApex 1.5: ", end="", flush=True)

        start_time = time.time()
        token_count = 0
        last_printed_len = 0   # chars of the decoded response already printed
        full_response_ids = []

        for _ in range(max_tk):
            # Keep only the most recent 1024 tokens as model context.
            current_ctx = input_ids[-1024:]
            input_array = np.array([current_ctx], dtype=np.int64)

            # Full forward pass over the whole context every step.
            outputs = session.run(None, {'input': input_array})

            # Logits at the last position, clipped to the usable vocab width.
            logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)

            next_token = apply_sampling(logits, temp, tk, rp, history)

            # Stop on end-of-text or any id beyond GPT-2's 50257 real tokens.
            if next_token == enc.eot_token or next_token >= 50257:
                break

            input_ids.append(next_token)
            full_response_ids.append(next_token)
            history.append(next_token)
            token_count += 1

            # Re-decode the whole response and print only the new suffix, so
            # multi-token BPE sequences render correctly once complete.
            decoded_text = enc.decode(full_response_ids)
            new_text = decoded_text[last_printed_len:]

            # The model starting a new "Instruction:" section means it is
            # done answering — stop without printing that section.
            if "Instruction:" in new_text:
                break

            print(new_text, end="", flush=True)
            last_printed_len = len(decoded_text)

        duration = time.time() - start_time
        tps = token_count / duration if duration > 0 else 0

        print(f"\n\n[Speed: {tps:.2f} tokens/s | Time: {duration:.2f}s]")
        print("-" * 40 + "\n")
|
|
if __name__ == "__main__":
    # Start the interactive loop only when run as a script, not on import.
    run_chat()