LH-Tech-AI committed on
Commit
d30d68c
·
verified ·
1 Parent(s): 430eb1c

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +157 -0
inference.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import sys
import site

# Best-effort bootstrap: make the pip-installed nvidia-cudnn shared libraries
# visible to the dynamic loader BEFORE onnxruntime is imported.  Mutating
# LD_LIBRARY_PATH inside an already-running process does not affect the
# loader, so the script re-execs itself once; the RESTARTED env var guards
# against an infinite restart loop.
try:
    # assumes the first site-packages entry holds the nvidia wheels — TODO confirm
    cudnn_path = os.path.join(site.getsitepackages()[0], 'nvidia', 'cudnn', 'lib')
    if os.path.exists(cudnn_path):
        if 'LD_LIBRARY_PATH' in os.environ:
            # Prepend so the bundled cuDNN wins over any system copy.
            os.environ['LD_LIBRARY_PATH'] = f"{cudnn_path}:{os.environ['LD_LIBRARY_PATH']}"
        else:
            os.environ['LD_LIBRARY_PATH'] = cudnn_path
        if "RESTARTED" not in os.environ:
            os.environ["RESTARTED"] = "1"
            # Replace the current process with a fresh interpreter that
            # inherits the updated environment.
            os.execv(sys.executable, [sys.executable] + sys.argv)
except Exception:
    # Deliberate best-effort: on any failure fall through to a normal start
    # (CPU execution still works without cuDNN).
    pass
17
+
18
+ import onnxruntime as ort
19
+
20
+ import tiktoken
21
+ import numpy as np
22
+ import time
23
+
24
# --- Configuration ---
MODEL_PATH = "Apex_1.5_Coder_DYNAMIC.onnx"
# Padded vocab size of the exported model; GPT-2's real vocab is 50257
# (ids >= 50257 are padding and are rejected during generation).
VOCAB_SIZE = 50304
enc = tiktoken.get_encoding("gpt2")

# Setup ONNX Session with CUDA
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

print(f"🚀 Loading Dynamic ONNX Model: {MODEL_PATH}...")
# Ordered preference list: try CUDA first, silently fall back to CPU.
providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kNextPowerOfTwo',
    }),
    'CPUExecutionProvider'
]
41
+
42
# Load the model; abort with a non-zero exit status if it fails.
try:
    session = ort.InferenceSession(MODEL_PATH, sess_options=options, providers=providers)
    print(f"✅ Active Provider: {session.get_providers()[0]}")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # BUGFIX: bare sys.exit() exits with status 0, making a fatal load
    # failure look like success to callers/scripts.  Exit with 1 instead.
    sys.exit(1)
48
+
49
def get_param(prompt, default):
    """Prompt the user for a value; empty input returns *default*.

    The reply is coerced to the type of *default* (e.g. int or float).

    BUGFIX: a non-coercible reply (e.g. "abc" for a float) previously
    raised ValueError and killed the whole chat loop; it now falls back
    to the default with a warning.
    """
    val = input(f"{prompt} (Default: {default}): ").strip()
    if not val:
        return default
    try:
        return type(default)(val)
    except (ValueError, TypeError):
        print(f" Invalid value, using default {default}.")
        return default
55
+
56
def apply_sampling(logits, temperature, top_k, repetition_penalty, history):
    """Sample one token id from *logits* after repetition penalty,
    temperature scaling and top-k filtering.

    Args:
        logits: 1-D array-like of raw scores (one per vocab id).
        temperature: softmax temperature; clamped to >= 1e-6.
        top_k: number of highest-scoring candidates kept; clamped to
            [1, vocab_size] (the original broke on top_k <= 0).
        repetition_penalty: CTRL-style penalty; > 1.0 discourages tokens
            already present in *history*.
        history: iterable of previously generated/seen token ids.

    Returns:
        int: the sampled token id.
    """
    # BUGFIX: work on a copy — the original applied the penalty in place,
    # mutating the caller's array.
    logits = np.asarray(logits, dtype=np.float32).copy()

    # 1. Repetition penalty, vectorized (the original looped in Python):
    # divide positive logits, multiply negative ones.
    if repetition_penalty != 1.0 and len(history) > 0:
        seen = np.unique(np.asarray(history, dtype=np.int64))
        seen = seen[(seen >= 0) & (seen < logits.shape[0])]
        if seen.size:
            vals = logits[seen]
            logits[seen] = np.where(vals > 0,
                                    vals / repetition_penalty,
                                    vals * repetition_penalty)

    # 2. Temperature scaling (guard against division by zero).
    logits /= max(temperature, 1e-6)

    # 3. Top-k filtering: everything below the k-th largest logit is
    # masked to -inf so it gets zero probability.
    k = int(min(max(top_k, 1), logits.shape[0]))
    kth_value = np.partition(logits, -k)[-k]
    logits[logits < kth_value] = -np.inf

    # 4. Numerically stable softmax, then draw one sample.
    exp_logits = np.exp(logits - np.max(logits))
    probs = exp_logits / np.sum(exp_logits)
    return int(np.random.choice(logits.shape[0], p=probs))
84
+
85
def run_chat():
    """Interactive chat REPL over the global ONNX session.

    Reads a prompt and per-turn sampling parameters from stdin, then
    streams generated tokens until EOS, the token budget, or the
    "Instruction:" stop marker is produced.  Relies on module globals:
    ``session``, ``enc``, ``VOCAB_SIZE``, ``get_param``, ``apply_sampling``.
    """
    print("\n" + "="*50)
    print(" APEX 1.5 DYNAMIC ONNX INTERACTIVE CHAT")
    print("="*50 + "\n")

    stop_marker = "Instruction:"

    while True:
        user_input = input("You: ")
        if user_input.lower() in ["exit", "quit", "beenden"]:
            break

        # Per-turn sampling parameters (empty input keeps the default).
        temp = get_param(" Temperature", 0.55)
        tk = get_param(" Top-K", 40)
        rp = get_param(" Repetition Penalty", 1.2)
        max_tk = get_param(" Max New Tokens", 500)

        # Instruction/Response prompt template; tokenize with GPT-2 BPE.
        prompt = f"Instruction:\n{user_input}\n\nResponse:\n"
        input_ids = enc.encode(prompt)
        history = list(input_ids)

        print("\nApex 1.5: ", end="", flush=True)

        start_time = time.time()
        token_count = 0
        last_printed_len = 0
        full_response_ids = []

        # Generation loop
        for _ in range(max_tk):
            # Dynamic input shape (1, seq_len); keep only the last 1024
            # tokens so the context never exceeds the model's window.
            current_ctx = input_ids[-1024:]
            input_array = np.array([current_ctx], dtype=np.int64)

            # Run ONNX inference.
            outputs = session.run(None, {'input': input_array})

            # Logits of the last position only: [batch, seq, vocab] -> (vocab,).
            logits = outputs[0][0, -1, :VOCAB_SIZE].astype(np.float32)

            next_token = apply_sampling(logits, temp, tk, rp, history)

            # Stop on EOS or on padded ids beyond GPT-2's real vocab (50257).
            if next_token == enc.eot_token or next_token >= 50257:
                break

            # Update state
            input_ids.append(next_token)
            full_response_ids.append(next_token)
            history.append(next_token)
            token_count += 1

            # Incremental streaming: re-decode the whole response and emit
            # only the new suffix (multi-byte characters can span tokens).
            decoded_text = enc.decode(full_response_ids)

            # BUGFIX: search for the stop marker in the WHOLE decoded
            # response, not just the latest chunk — the marker's tokens can
            # be split across iterations, so the original per-chunk check
            # could never match and the model would keep generating.
            marker_pos = decoded_text.find(stop_marker)
            if marker_pos != -1:
                # Emit any text that precedes the marker, then stop.
                print(decoded_text[last_printed_len:marker_pos], end="", flush=True)
                break

            print(decoded_text[last_printed_len:], end="", flush=True)
            last_printed_len = len(decoded_text)

        duration = time.time() - start_time
        tps = token_count / duration if duration > 0 else 0

        print(f"\n\n[Speed: {tps:.2f} tokens/s | Time: {duration:.2f}s]")
        print("-" * 40 + "\n")
155
+
156
# Entry point: start the interactive chat only when run as a script.
if __name__ == "__main__":
    run_chat()