from transformers import AutoTokenizer
from ctypes_bindings import *
from model_configs import model_configs
import threading
import ctypes
import time
import sys
import os

MODEL_PATH = "./models"

# Shared state used to pass streamed output from the C callback back to the Python side
global_text = []
global_state = -1
split_byte_data = bytes(b"")

# Build a dict of the model configs whose files actually exist in the ./models directory.
# This becomes the content of the model selector's drop-down menu.
def available_models():
    if not os.path.exists(MODEL_PATH):
        os.mkdir(MODEL_PATH)
    # Initialize the dict of available models as empty
    rkllm_model_files = {}
    # Populate the dictionary with found models and their base configurations
    for family, config in model_configs.items():
        for model, details in config["models"].items():
            filename = details["filename"]
            if os.path.exists(os.path.join(MODEL_PATH, filename)):
                rkllm_model_files[model] = {
                    "name": model,
                    "family": family,
                    "filename": filename,
                    "config": config["base_config"],
                }
    return rkllm_model_files

# Define the callback function
def callback_impl(result, userdata, state):
    global global_text, global_state, split_byte_data
    if state == LLMCallState.RKLLM_RUN_FINISH:
        global_state = state
        print("\n")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        global_state = state
        print("run error")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER:
        '''
        When using the GET_LAST_HIDDEN_LAYER mode, the callback returns the memory pointer
        last_hidden_layer, the number of tokens num_tokens, and the size of the hidden layer
        embd_size. With these three parameters the data can be read out of last_hidden_layer.
        Note: the data must be copied during the current callback; if it is not retrieved in
        time, the pointer is released by the next callback.
        '''
        if result.contents.last_hidden_layer.embd_size != 0 and result.contents.last_hidden_layer.num_tokens != 0:
            data_size = (result.contents.last_hidden_layer.embd_size
                         * result.contents.last_hidden_layer.num_tokens
                         * ctypes.sizeof(ctypes.c_float))
            print(f"data_size: {data_size}")
            global_text.append(f"data_size: {data_size}\n")
            output_path = os.getcwd() + "/last_hidden_layer.bin"
            with open(output_path, "wb") as outFile:
                data = ctypes.cast(result.contents.last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float))
                float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float))
                float_array = float_array_type.from_address(ctypes.addressof(data.contents))
                outFile.write(bytearray(float_array))
                print(f"Data saved to {output_path} successfully!")
                global_text.append(f"Data saved to {output_path} successfully!")
        else:
            print("Invalid hidden layer data.")
            global_text.append("Invalid hidden layer data.")
        global_state = state
        time.sleep(0.05)
        sys.stdout.flush()
    else:
        # Save the output token text and the RKLLM running state
        global_state = state
        # The C side may hand back an incomplete UTF-8 sequence; buffer it until it decodes cleanly
        try:
            global_text.append((split_byte_data + result.contents.text).decode('utf-8'))
            print((split_byte_data + result.contents.text).decode('utf-8'), end='')
            split_byte_data = bytes(b"")
        except (UnicodeDecodeError, TypeError):
            if result.contents.text is not None:
                split_byte_data += result.contents.text
        sys.stdout.flush()

# Connect the callback function between the Python side and the C++ side
callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int)
callback = callback_type(callback_impl)
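# Illustrative shape of a model_configs entry (the values below are hypothetical; the
# real definitions live in model_configs.py). Only the keys shown here are read by
# available_models() above and by RKLLMLoaderClass below:
#
#   model_configs = {
#       "SomeFamily": {
#           "base_config": {
#               "st_model_id": "org/some-model",   # Hugging Face id used to fetch the tokenizer
#               "system_prompt": "You are a helpful assistant.",
#               "max_context_len": 4096,
#               "max_new_tokens": 1024,
#               "top_k": 20,
#               "top_p": 0.9,
#               "temperature": 0.7,
#               "repeat_penalty": 1.1,
#               "frequency_penalty": 0.0,
#           },
#           "models": {
#               "Some Model 1.5B": {"filename": "some-model-1.5b.rkllm"},
#           },
#       },
#   }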
class RKLLMLoaderClass:
    def __init__(self, model="", qtype="w8a8", opt="1", hybrid_quant="1.0"):
        self.qtype = qtype
        self.opt = opt
        self.model = model
        self.hybrid_quant = hybrid_quant
        if self.model == "":
            print("No models loaded yet!")
        else:
            # Look up the selected model and its base configuration
            self.available_models = available_models()
            self.model = self.available_models[model]
            self.family = self.model["family"]
            self.model_path = "models/" + self.model["filename"]
            self.base_config = self.model["config"]
            self.model_name = self.model["name"]
            self.st_model_id = self.base_config["st_model_id"]
            self.system_prompt = self.base_config["system_prompt"]
            # Fill in the RKLLM parameter struct from the base config
            self.rkllm_param = RKLLMParam()
            self.rkllm_param.model_path = bytes(self.model_path, 'utf-8')
            self.rkllm_param.max_context_len = self.base_config["max_context_len"]
            self.rkllm_param.max_new_tokens = self.base_config["max_new_tokens"]
            self.rkllm_param.skip_special_token = True
            self.rkllm_param.top_k = self.base_config["top_k"]
            self.rkllm_param.top_p = self.base_config["top_p"]
            # self.rkllm_param.min_p = 0.1
            self.rkllm_param.temperature = self.base_config["temperature"]
            self.rkllm_param.repeat_penalty = self.base_config["repeat_penalty"]
            self.rkllm_param.frequency_penalty = self.base_config["frequency_penalty"]
            self.rkllm_param.presence_penalty = 0.0
            self.rkllm_param.mirostat = 0
            self.rkllm_param.mirostat_tau = 5.0
            self.rkllm_param.mirostat_eta = 0.1
            self.rkllm_param.is_async = False
            self.rkllm_param.img_start = "".encode('utf-8')
            self.rkllm_param.img_end = "".encode('utf-8')
            self.rkllm_param.img_content = "".encode('utf-8')
            self.rkllm_param.extend_param.base_domain_id = 0
            # Bind the C library entry points and initialize the model handle
            self.handle = RKLLM_Handle_t()
            self.rkllm_init = rkllm_lib.rkllm_init
            self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type]
            self.rkllm_init.restype = ctypes.c_int
            self.rkllm_init(self.handle, self.rkllm_param, callback)
            self.rkllm_run = rkllm_lib.rkllm_run
            self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p]
            self.rkllm_run.restype = ctypes.c_int
            self.rkllm_abort = rkllm_lib.rkllm_abort
            self.rkllm_abort.argtypes = [RKLLM_Handle_t]
            self.rkllm_abort.restype = ctypes.c_int
            self.rkllm_destroy = rkllm_lib.rkllm_destroy
            self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
            self.rkllm_destroy.restype = ctypes.c_int

    # Record the user's input prompt
    def get_user_input(self, user_message, history):
        history = history + [[user_message, None]]
        return "", history

    # Convert a Python list (the tokenizer's output) to a ctypes array
    def tokens_to_ctypes_array(self, tokens, ctype):
        return (ctype * len(tokens))(*tokens)
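    # Streaming flow: get_RKLLM_output() below tokenizes the chat, launches run() on a
    # worker thread, and drains the strings that callback_impl() appends to global_text,
    # yielding the partially built assistant message back to the Gradio UI as it grows.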
    # Run inference
    def run(self, prompt):
        self.rkllm_infer_params = RKLLMInferParam()
        ctypes.memset(ctypes.byref(self.rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam))
        self.rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        self.rkllm_input = RKLLMInput()
        self.rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_TOKEN
        self.rkllm_input.input_data.token_input.input_ids = self.tokens_to_ctypes_array(prompt, ctypes.c_int)
        self.rkllm_input.input_data.token_input.n_tokens = ctypes.c_ulong(len(prompt))
        self.rkllm_run(self.handle, ctypes.byref(self.rkllm_input), ctypes.byref(self.rkllm_infer_params), None)
        return

    # Release the RKLLM object from memory
    def release(self):
        self.rkllm_abort(self.handle)
        self.rkllm_destroy(self.handle)

    # Retrieve the output from the RKLLM model and print it in a streaming manner
    def get_RKLLM_output(self, message, history):
        # Link the global variables used by the callback function to collect the output
        global global_text, global_state
        global_text = []
        global_state = -1
        user_prompt = {"role": "user", "content": message}
        history.append(user_prompt)
        # Gemma 2 does not support a system prompt
        if self.system_prompt == "":
            prompt = [user_prompt]
        else:
            prompt = [
                {"role": "system", "content": self.system_prompt},
                user_prompt
            ]
        # print(prompt)
        TOKENIZER_PATH = "%s/%s" % (MODEL_PATH, self.st_model_id.replace("/", "-"))
        if not os.path.exists(TOKENIZER_PATH):
            print("Tokenizer not cached locally, downloading to %s" % TOKENIZER_PATH)
            os.mkdir(TOKENIZER_PATH)
            tokenizer = AutoTokenizer.from_pretrained(self.st_model_id, trust_remote_code=True)
            tokenizer.save_pretrained(TOKENIZER_PATH)
        else:
            tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
        prompt = tokenizer.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True)
        # response = {"role": "assistant", "content": "Loading..."}
        response = {"role": "assistant", "content": ""}
        history.append(response)
        # Run inference on a worker thread and stream whatever the callback has collected so far
        model_thread = threading.Thread(target=self.run, args=(prompt,))
        model_thread.start()
        model_thread_finished = False
        while not model_thread_finished:
            while len(global_text) > 0:
                response["content"] += global_text.pop(0)
                # Marco-o1: escape its <Thought>/<Output> tags so they render literally in the UI
                response["content"] = str(response["content"]).replace("<Thought>", "\\<Thought\\>")
                response["content"] = str(response["content"]).replace("</Thought>", "\\<\\/Thought\\>")
                response["content"] = str(response["content"]).replace("<Output>", "\\<Output\\>")
                response["content"] = str(response["content"]).replace("</Output>", "\\<\\/Output\\>")
                time.sleep(0.005)
                # Gradio automatically pushes the result returned by the yield statement when calling the then method
                yield response
            model_thread.join(timeout=0.005)
            model_thread_finished = not model_thread.is_alive()
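
# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption-laden example, not part of the original
# module): it presumes at least one converted .rkllm model sits in ./models,
# that its tokenizer id in model_configs is reachable, and that the RKLLM
# runtime library loads on this machine. Normally this class is driven by a
# Gradio UI that consumes get_RKLLM_output() as a streaming generator.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    models = available_models()
    if not models:
        print("No .rkllm models found in %s" % MODEL_PATH)
        sys.exit(1)
    model_name = next(iter(models))            # pick the first detected model
    rkllm = RKLLMLoaderClass(model=model_name)
    history = []
    try:
        # Drain the generator; tokens are also printed by callback_impl as they arrive
        for _ in rkllm.get_RKLLM_output("Hello! Who are you?", history):
            pass
        print("\nFinal response: %s" % history[-1]["content"])
    finally:
        rkllm.release()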