카테고리 없음
[Stt] transformer stt 로컬 모델 로드
3pie
2024. 3. 4. 11:41
이 모델은 자동으로 다운받아지고 모델아이디로 로드 하는 방법을 기본적으로 알려주기때문에 모델을 로컬에서 저장하고 갖고 있는 방법에 대한 따로 명시해놓은게 없었다. 그래서 내부 코드를 보고 직접 찾아 냈는데,
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
위 코드로 다운받게 되면 저장 위치가
/root/.cache/huggingface/hub/models--openai--whisper-large-v3
여기에 실제 모델이 safetensorns 형태로 저장되게 된다.
하지만 이 형태로 사용하게 되면 로컬 형태로 사용이 안되는데 그래서
transformers/utils/hub.py
cached_file 이랑 함수를 보면 아래와 같은데 model_id를 사용해서 /root 위치에 있는걸 활용하는게 아닌 local 위치에서 사용하도록 하는 방법 이다.
resolved_file 이란 config.json 파일이다.
cached_file 함수의 일부분인데 이렇게 디렉토리 path를 파라미터로 전달하면 다운받지 않고 로드를 할수 있게 된다.
이렇게 모델 호출을 하면되는데
model = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path = './whisper_transformer/models--openai--whisper-large-v3/snapshots',
subfolder='1ecca609f9a5ae2cd97~~~',filename='config.json', torch_dtype=torch_dtype, low_cpu_mem_usage=False, use_safetensors=True
)
현재 버전으로는 키워드 에러가 나는 부분이 생겨서
transformers/modeling_utils.py
에서 from_pretrained 함수에서 에러가 난다. 파일네임을 전달해줘서 그런건데 지워주면 된다.