diff --git a/data/shakespeare/prepare.py b/data/shakespeare/prepare.py index 06573ac..71c88da 100644 --- a/data/shakespeare/prepare.py +++ b/data/shakespeare/prepare.py @@ -4,12 +4,13 @@ import tiktoken import numpy as np # download the tiny shakespeare dataset -if not os.path.exists('input.txt'): +input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') +if not os.path.exists(input_file_path): data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' - with open('input.txt', 'w') as f: + with open(input_file_path, 'w') as f: f.write(requests.get(data_url).text) -with open('input.txt', 'r') as f: +with open(input_file_path, 'r') as f: data = f.read() n = len(data) train_data = data[:int(n*0.9)]